Commit 2676ad00 authored by laibao's avatar laibao
Browse files

v1: add SnapKV Triton KV compression

Introduce v1 KV compression modules (budget + SnapKV Triton kernel) and integrate with scheduler/cache managers.
parent 155c8a13
...@@ -141,6 +141,34 @@ if TYPE_CHECKING: ...@@ -141,6 +141,34 @@ if TYPE_CHECKING:
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
# KV compression (token-shared) for v1 paged attention.
# When enabled, vLLM decouples logical positions from KV cache positions
# and keeps only a subset of prompt tokens in KV cache during prefill.
VLLM_ENABLE_KV_COMPRESSION: bool = False
# KV compression policy for selecting which prompt KV entries to retain.
# Currently only "topk" is supported.
VLLM_KV_COMPRESSION_POLICY: str = "topk"
# Target prompt KV budget for token-shared compression.
# If PROMPT_BUDGET >= 0, it takes precedence over PROMPT_RATIO.
# The budget/ratio applies to non-protected prompt tokens only.
VLLM_KV_COMPRESSION_PROMPT_RATIO: float = 1.0
VLLM_KV_COMPRESSION_PROMPT_BUDGET: int = -1
VLLM_KV_COMPRESSION_PROTECTED_PREFIX: int = 0
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX: int = 0
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN: bool = True
# SnapKV-like scoring wi这个ndow used by the "topk" policy.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW: int = 32
# Use Triton SnapKV scoring on ROCm (experimental). Set to 0 to force the
# PyTorch reference implementation.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM: bool = True
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing a single selection across all layers in a forward pass.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER: bool = False
# Run KV compaction writeback (reshape_and_cache_*) on a separate CUDA
# stream to overlap with compute (experimental).
VLLM_KV_COMPRESSION_ASYNC_WRITEBACK: bool = False
# Free unused tail KV cache blocks after prompt compaction (experimental).
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
# add envs # add envs
...@@ -1054,6 +1082,50 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1054,6 +1082,50 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN": "VLLM_USE_TRITON_PREFIX_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# Enable token-shared KV compression for v1 paged attention (experimental).
# This feature currently targets long-prompt prefill memory reduction.
"VLLM_ENABLE_KV_COMPRESSION":
lambda: bool(int(os.getenv("VLLM_ENABLE_KV_COMPRESSION", "0"))),
# KV compression policy ("topk").
"VLLM_KV_COMPRESSION_POLICY":
lambda: os.getenv("VLLM_KV_COMPRESSION_POLICY", "topk").lower(),
# Target fraction of non-protected prompt tokens to keep in KV cache.
"VLLM_KV_COMPRESSION_PROMPT_RATIO":
lambda: float(os.getenv("VLLM_KV_COMPRESSION_PROMPT_RATIO", "1.0")),
# Target number of non-protected prompt tokens to keep in KV cache.
# If >= 0, this takes precedence over VLLM_KV_COMPRESSION_PROMPT_RATIO.
"VLLM_KV_COMPRESSION_PROMPT_BUDGET":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROMPT_BUDGET", "-1")),
# Always keep the first N prompt tokens in KV cache (e.g. BOS/system).
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_PREFIX", "0")),
# Always keep the last N prompt tokens in KV cache.
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_SUFFIX", "0")),
# Always keep the last prompt token (prompt_len - 1) when it is scheduled.
"VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN", "1"))),
# SnapKV-like scoring window size for the "topk" policy.
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_WINDOW", "32")),
# Enable Triton SnapKV scoring on ROCm (experimental).
"VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM", "1"))),
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing one selection across layers in a forward pass.
"VLLM_KV_COMPRESSION_TOPK_PER_LAYER":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_TOPK_PER_LAYER", "0"))),
# If set, run KV compaction writeback on a separate CUDA stream to overlap
# cache writes with compute (experimental).
"VLLM_KV_COMPRESSION_ASYNC_WRITEBACK":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_ASYNC_WRITEBACK", "0"))),
# If set, free unused tail KV cache blocks after prompt compaction.
"VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS", "0"))),
# If set, vLLM will use optimized MLA attention optimizations. # If set, vLLM will use optimized MLA attention optimizations.
"VLLM_USE_TRITON_OPT_MLA": "VLLM_USE_TRITON_OPT_MLA":
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Optional, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
if HAS_TRITON:
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk},
num_warps=num_warps,
num_stages=num_stages,
) for bq in [32, 64] for bk in [32, 64] for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)
s = tl.dot(q_rows, k_blk.T) * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton.jit
def _lse_and_store_logits_kernel_rocm_safe(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""ROCm-safe variant of `_lse_and_store_logits_kernel`.
On some ROCm + Triton (HIP) stacks we have observed memory corruption
from the tl.dot-based implementation. This variant avoids `tl.dot` and
instead computes dot-products via explicit outer-product accumulation.
"""
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
# Accumulate s = Q @ K^T in fp32 via outer products.
s = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
for ds in tl.static_range(0, D, BLOCK_D):
offs_d = ds + tl.arange(0, BLOCK_D)
dmask = offs_d < D
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_chunk = tl.load(
q_ptrs,
mask=row_mask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BQ, BD]
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_chunk = tl.load(
k_ptrs,
mask=kmask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BK, BD]
s += tl.sum(q_chunk[:, None, :] * k_chunk[None, :, :], axis=2)
s = s * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
m,
mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
S,
mask=row_mask)
@triton.autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64] for bk in [32, 64, 128]
],
key=["HK", "HQ"],
)
@triton.jit
def _scores_from_logits_kernel(
cu_k,
w_b,
in_m,
in_S,
LOGITS,
OUT,
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
DO_POOL: tl.constexpr,
KPOOL: tl.constexpr,
PROTECT_LAST: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
scores = tl.zeros([BLOCK_K], dtype=tl.float32)
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = (in_m + b * STRIDE_M_B + hk * STRIDE_M_H +
row0 * STRIDE_M_R)
S_ptr = (in_S + b * STRIDE_S_B + hk * STRIDE_S_H +
row0 * STRIDE_S_R)
m = tl.load(m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"))
S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
mask=rmask,
other=0.0)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] *
STRIDE_LG_R)
s_T = tl.load(log_ptrs,
mask=kmask[:, None] & rmask[None, :],
other=-float("inf"))
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
scores += tl.sum(probs_T, 1)
if DO_POOL and (KPOOL > 1):
i = tl.arange(0, BLOCK_K)[:, None]
j = tl.arange(0, BLOCK_K)[None, :]
band = (j <= i) & ((i - j) < KPOOL)
band = band & kmask[None, :]
sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
denom = tl.sum(band, 1).to(tl.float32)
denom = tl.where(denom > 0, denom, 1.0)
scores = sums / denom
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, scores, mask=kmask)
if PROTECT_LAST:
pad_beg = k_eff_end
pad_end = k_end
if pad_end > pad_beg:
for ks in tl.range(pad_beg, pad_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < pad_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs,
tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
mask=kmask)
@triton.autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT,
cu_k,
w_b,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr,
EPS: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1] int32
cu_seqlens_k: torch.Tensor, # [B+1] int32
w: Union[int, torch.Tensor], # [B] int32 or scalar
sm_scale: Optional[float] = None,
*,
pool: bool = True,
kpool: int = 5,
protect_last: bool = True,
normalize: bool = False,
) -> torch.Tensor:
"""SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32."""
if not HAS_TRITON:
raise RuntimeError("Triton is not available.")
if q.device.type != "cuda" or k.device.type != "cuda":
raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
if q.ndim != 3 or k.ndim != 3:
raise ValueError("q and k must be 3D tensors.")
if q.stride(-1) != 1 or k.stride(-1) != 1:
raise ValueError("Last dim must be contiguous for Triton SnapKV.")
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
if D != Dk:
raise ValueError("q and k must have the same head size.")
if (Hq % Hk) != 0:
raise ValueError("Hq must be a multiple of Hk.")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = int(cu_seqlens_q.numel() - 1)
if B != int(cu_seqlens_k.numel() - 1):
raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
G = Hq // Hk
if isinstance(w, int):
max_w = int(w)
w = torch.full((B, ),
fill_value=max_w,
device=device,
dtype=torch.int32)
else:
if w.numel() != B:
raise ValueError("w must have shape [B].")
w = w.to(device=device, dtype=torch.int32)
max_w = int(w.max().item())
rows_max = max_w * G
if rows_max <= 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
if kpool < 1:
raise ValueError("kpool must be >= 1.")
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
S_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
logits_buf = torch.empty((N_k, Hk, rows_max),
dtype=torch.float32,
device=device)
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(meta):
return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])
cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32)
cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32)
# NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
# correctness issues (silent memory corruption) observed with the tl.dot
# implementation on some stacks.
is_rocm = getattr(torch.version, "hip", None) is not None
if is_rocm:
_lse_and_store_logits_kernel_rocm_safe[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
BLOCK_Q=32,
BLOCK_K=32,
BLOCK_D=16,
num_warps=4,
num_stages=1,
)
else:
_lse_and_store_logits_kernel[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=rows_max,
)
_scores_from_logits_kernel[(B, Hk)](
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=G,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=pool,
KPOOL=kpool,
PROTECT_LAST=protect_last,
)
if normalize:
_zscore_per_batch_epilogue[(B, )](
out,
cu_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
return out
...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC): ...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens) manager.remove_skipped_blocks(request_id, num_computed_tokens)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
Returns True if any blocks were freed.
"""
truncated = False
for manager in self.single_type_managers:
truncated = manager.truncate_to_num_tokens(request_id,
num_tokens) or truncated
return truncated
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
""" """
Get the blocks for the request. Get the blocks for the request.
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import sha256 from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
...@@ -251,9 +253,17 @@ class KVCacheManager: ...@@ -251,9 +253,17 @@ class KVCacheManager:
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens) num_new_computed_tokens)
num_tokens_need_slot = min( if envs.VLLM_ENABLE_KV_COMPRESSION and not current_platform.is_tpu():
num_computed_tokens + num_new_tokens + num_lookahead_tokens, # KV compression decouples logical positions from KV cache
self.max_model_len) # positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot = min(
request.num_kv_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
else:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id, request_id=request.request_id,
...@@ -385,6 +395,14 @@ class KVCacheManager: ...@@ -385,6 +395,14 @@ class KVCacheManager:
return KVCacheBlocks( return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids() self.coordinator.get_blocks(request_id)).get_block_ids()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort operation that may free blocks back to the pool.
Returns True if any blocks were freed.
"""
return self.coordinator.truncate_to_num_tokens(request_id, num_tokens)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled."""
if self.enable_caching: if self.enable_caching:
......
...@@ -31,6 +31,7 @@ class NewRequestData: ...@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
@classmethod @classmethod
...@@ -49,6 +50,7 @@ class NewRequestData: ...@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
block_ids=block_ids, block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
) )
...@@ -62,6 +64,7 @@ class NewRequestData: ...@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -76,6 +79,7 @@ class NewRequestData: ...@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -93,6 +97,7 @@ class CachedRequestData: ...@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]] new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int] num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
...@@ -106,6 +111,7 @@ class CachedRequestData: ...@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[], new_token_ids=[],
new_block_ids=[], new_block_ids=[],
num_computed_tokens=[], num_computed_tokens=[],
num_kv_tokens=[],
) )
......
...@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy, ...@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs from vllm import envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface): ...@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla self.use_mla = vllm_config.model_config.use_mla
# KV compression is a GPU-only feature in this fork; ignore it on TPU.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and not current_platform.is_tpu())
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_tpu():
logger.warning_once(
"KV compression is not supported on TPU; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
...@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface): ...@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface): ...@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface): ...@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -321,8 +378,12 @@ class Scheduler(SchedulerInterface): ...@@ -321,8 +378,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = ( if request.request_id in force_replace_block_ids:
new_blocks.get_block_ids()) req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface): ...@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface): ...@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface): ...@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface): ...@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface): ...@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface): ...@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -941,8 +1026,12 @@ class Scheduler(SchedulerInterface): ...@@ -941,8 +1026,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = ( if request.request_id in force_replace_block_ids:
new_blocks.get_block_ids()) req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface): ...@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -1076,8 +1166,51 @@ class Scheduler(SchedulerInterface): ...@@ -1076,8 +1166,51 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items(): for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id] request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs. # Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because # NOTE: We shouldn't do self.finished_req_ids.clear() here because
...@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface): ...@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int], num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]], spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]], req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = [] new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
...@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface): ...@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do num_kv_tokens.append(req.num_kv_tokens)
# in-place appending so that we don't need to allocate a new list. resumed_from_preemption.append(
resumed_from_preemption = [False] * len(running_reqs) (req in resumed_reqs) or (req_id in force_replace_block_ids))
resumed_from_preemption += [True] * len(resumed_reqs)
return CachedRequestData( return CachedRequestData(
req_ids=req_ids, req_ids=req_ids,
...@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface): ...@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids, new_token_ids=new_token_ids,
new_block_ids=new_block_ids, new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens, num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
) )
def _try_schedule_encoder_inputs( def _try_schedule_encoder_inputs(
...@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface): ...@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling. # Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready. # Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
......
...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC): ...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.free_blocks(ordered_blocks) self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None) self.num_cached_block.pop(request_id, None)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort optimization hook. Subclasses may override this
to free no-longer-needed blocks (e.g., after KV compaction). The default
implementation is a no-op.
"""
return False
@abstractmethod @abstractmethod
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention. # No need to remove blocks for full attention.
pass pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
num_tokens = max(int(num_tokens), 0)
blocks = self.req_to_blocks.get(request_id)
if not blocks:
return False
num_required_blocks = cdiv(num_tokens, self.block_size)
if num_required_blocks >= len(blocks):
return False
removed_blocks = blocks[num_required_blocks:]
del blocks[num_required_blocks:]
self.block_pool.free_blocks(reversed(removed_blocks))
if request_id in self.num_cached_block:
self.num_cached_block[request_id] = min(
self.num_cached_block[request_id], len(blocks))
return True
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id] blocks = self.req_to_blocks[request_id]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
start = _clamp_int(start_pos, 0, prompt_len)
end = _clamp_int(end_pos, 0, prompt_len)
if end <= start:
return 0
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
keep_prefix = _intersection_len(start, end, 0, prefix_len)
keep_suffix = _intersection_len(start, end, suffix_start, prompt_len)
overlap = _intersection_len(start, end, suffix_start, prefix_len)
kept = keep_prefix + keep_suffix - overlap
if keep_last_token:
last = prompt_len - 1
if start <= last < end:
already_kept = (last < prefix_len) or (last >= suffix_start)
if not already_kept:
kept += 1
return kept
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
return _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
total_keep = _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
if total_keep <= 0:
return 0
cand_upto_start = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=start_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
cand_upto_end = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
step_total = max(0, cand_upto_end - cand_upto_start)
if step_total == 0:
return 0
bud_upto_start = (total_keep * cand_upto_start) // total
bud_upto_end = (total_keep * cand_upto_end) // total
step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total)
...@@ -79,6 +79,10 @@ class Request: ...@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0 self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
......
...@@ -63,6 +63,11 @@ class BlockTable: ...@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None: def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
......
...@@ -38,6 +38,7 @@ class CachedRequestState: ...@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
spec_token_ids: list[int] = None spec_token_ids: list[int] = None
...@@ -114,6 +115,13 @@ class InputBatch: ...@@ -114,6 +115,13 @@ class InputBatch:
) )
self.num_computed_tokens_cpu = \ self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
...@@ -348,6 +356,7 @@ class InputBatch: ...@@ -348,6 +356,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params: if sampling_params := request.sampling_params:
...@@ -504,6 +513,8 @@ class InputBatch: ...@@ -504,6 +513,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\ self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1] self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\ self.top_p_cpu[i1], self.top_p_cpu[i2] =\
...@@ -602,6 +613,8 @@ class InputBatch: ...@@ -602,6 +613,8 @@ class InputBatch:
last_req_index] last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[ self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index] last_req_index]
......
...@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget ...@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec, KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec) SlidingWindowSpec)
from vllm.v1.kv_compression.budget import compute_topk_budget_step
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
...@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.attention_chunk_size = model_config.attention_chunk_size self.attention_chunk_size = model_config.attention_chunk_size
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_KV_COMPRESSION:
# KV compression changes the effective KV sequence layout and
# invalidates cascade attention assumptions (common-prefix blocks).
self.cascade_attn_enabled = False
# Whether the current step needs KV compaction work (score/topk/dst).
# This is set per-step in `_prepare_inputs`.
self.kv_compression_needs_compaction: bool = False
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
...@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy() self.positions_np = self.positions_cpu.numpy()
# KV positions are decoupled from logical positions when KV compression
# is enabled. We keep a separate buffer to avoid recomputing or
# overwriting `positions_np` (used for RoPE / input token lookup).
self.kv_positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.kv_positions_np = self.kv_positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
...@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
self.kv_compression_must_keep_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_must_keep_np = self.kv_compression_must_keep_cpu.numpy()
self.kv_compression_must_keep = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device=self.device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
self.kv_compression_topk_budget_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_topk_budget_np = self.kv_compression_topk_budget_cpu.numpy()
self.kv_compression_topk_budget = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
...@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
generator=generator, generator=generator,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[], output_token_ids=[],
lora_request=new_req_data.lora_request, lora_request=new_req_data.lora_request,
) )
...@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i] num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i] new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i] resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
spec_token_ids = ( spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
...@@ -545,7 +592,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -545,7 +592,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index) self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached. # because the sampled tokens are already cached.
...@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
np.add(self.input_batch.num_computed_tokens_cpu[req_indices], np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange, arange,
out=positions_np) out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally
# kept and there is no Top-K budget, KV compaction is a no-op and we
# can skip score/topk/dst entirely in the attention backend.
self.kv_compression_needs_compaction = (not must_keep_np.all()) or (
topk_budget_np > 0).any()
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions. # Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
...@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
block_size = kv_cache_group_spec.kv_cache_spec.block_size block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[ block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id] kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2. # where K is the max_num_blocks_per_req and the block size is 2.
...@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# block_size. # block_size.
block_table_indices = ( block_table_indices = (
req_indices * block_table.max_num_blocks_per_req + req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size) slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor() block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten( block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy() )[block_table_indices].numpy()
block_offsets = positions_np % block_size block_offsets = slot_positions_np % block_size
np.add( np.add(
block_numbers * block_size, block_numbers * block_size,
block_offsets, block_offsets,
...@@ -707,9 +832,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -707,9 +832,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0 self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = ( if use_kv_compression:
self.input_batch.num_computed_tokens_cpu[:num_reqs] + self.seq_lens_np[:num_reqs] = (
num_scheduled_tokens) self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids[:total_num_scheduled_tokens].copy_(
...@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
if use_kv_compression:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache # Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0) self.seq_lens[num_reqs:].fill_(0)
...@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
assert len(self.attn_backends) == 0 and len( assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders self.attn_metadata_builders
) == 0, "Attention backends are already initialized" ) == 0, "Attention backends are already initialized"
if envs.VLLM_ENABLE_KV_COMPRESSION and self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
for i, kv_cache_group_spec in enumerate( for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups): kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec kv_cache_spec = kv_cache_group_spec.kv_cache_spec
...@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
raise NotImplementedError( raise NotImplementedError(
"Non-Attention backend is not supported by V1 " "Non-Attention backend is not supported by V1 "
"GPUModelRunner.") "GPUModelRunner.")
if (envs.VLLM_ENABLE_KV_COMPRESSION
and attn_backend_i.get_name() != "FLASH_ATTN_VLLM_V1"):
raise ValueError(
"KV compression currently requires "
"VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.")
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
if envs.VLLM_ENABLE_KV_COMPRESSION:
raise ValueError(
"KV compression is currently only supported for "
"Transformer attention layers.")
attn_backend_i = Mamba2AttentionBackend attn_backend_i = Mamba2AttentionBackend
else: else:
raise ValueError( raise ValueError(
...@@ -3689,4 +3841,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3689,4 +3841,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if envs.VLLM_USE_ZERO_MTP: if envs.VLLM_USE_ZERO_MTP:
GPUModelRunner=GPUModelRunnerMTP GPUModelRunner=GPUModelRunnerMTP
else: else:
GPUModelRunner=GPUModelRunnerBase GPUModelRunner=GPUModelRunnerBase
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment