Commit 2676ad00 authored by laibao's avatar laibao
Browse files

v1: add SnapKV Triton KV compression

Introduce v1 KV compression modules (budget + SnapKV Triton kernel) and integrate with scheduler/cache managers.
parent 155c8a13
......@@ -141,6 +141,34 @@ if TYPE_CHECKING:
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
# KV compression (token-shared) for v1 paged attention.
# When enabled, vLLM decouples logical positions from KV cache positions
# and keeps only a subset of prompt tokens in KV cache during prefill.
VLLM_ENABLE_KV_COMPRESSION: bool = False
# KV compression policy for selecting which prompt KV entries to retain.
# Currently only "topk" is supported.
VLLM_KV_COMPRESSION_POLICY: str = "topk"
# Target prompt KV budget for token-shared compression.
# If PROMPT_BUDGET >= 0, it takes precedence over PROMPT_RATIO.
# The budget/ratio applies to non-protected prompt tokens only.
VLLM_KV_COMPRESSION_PROMPT_RATIO: float = 1.0
VLLM_KV_COMPRESSION_PROMPT_BUDGET: int = -1
VLLM_KV_COMPRESSION_PROTECTED_PREFIX: int = 0
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX: int = 0
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN: bool = True
# SnapKV-like scoring wi这个ndow used by the "topk" policy.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW: int = 32
# Use Triton SnapKV scoring on ROCm (experimental). Set to 0 to force the
# PyTorch reference implementation.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM: bool = True
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing a single selection across all layers in a forward pass.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER: bool = False
# Run KV compaction writeback (reshape_and_cache_*) on a separate CUDA
# stream to overlap with compute (experimental).
VLLM_KV_COMPRESSION_ASYNC_WRITEBACK: bool = False
# Free unused tail KV cache blocks after prompt compaction (experimental).
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
# add envs
......@@ -1054,6 +1082,50 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# Enable token-shared KV compression for v1 paged attention (experimental).
# This feature currently targets long-prompt prefill memory reduction.
"VLLM_ENABLE_KV_COMPRESSION":
lambda: bool(int(os.getenv("VLLM_ENABLE_KV_COMPRESSION", "0"))),
# KV compression policy ("topk").
"VLLM_KV_COMPRESSION_POLICY":
lambda: os.getenv("VLLM_KV_COMPRESSION_POLICY", "topk").lower(),
# Target fraction of non-protected prompt tokens to keep in KV cache.
"VLLM_KV_COMPRESSION_PROMPT_RATIO":
lambda: float(os.getenv("VLLM_KV_COMPRESSION_PROMPT_RATIO", "1.0")),
# Target number of non-protected prompt tokens to keep in KV cache.
# If >= 0, this takes precedence over VLLM_KV_COMPRESSION_PROMPT_RATIO.
"VLLM_KV_COMPRESSION_PROMPT_BUDGET":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROMPT_BUDGET", "-1")),
# Always keep the first N prompt tokens in KV cache (e.g. BOS/system).
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_PREFIX", "0")),
# Always keep the last N prompt tokens in KV cache.
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_SUFFIX", "0")),
# Always keep the last prompt token (prompt_len - 1) when it is scheduled.
"VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN", "1"))),
# SnapKV-like scoring window size for the "topk" policy.
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_WINDOW", "32")),
# Enable Triton SnapKV scoring on ROCm (experimental).
"VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM", "1"))),
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing one selection across layers in a forward pass.
"VLLM_KV_COMPRESSION_TOPK_PER_LAYER":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_TOPK_PER_LAYER", "0"))),
# If set, run KV compaction writeback on a separate CUDA stream to overlap
# cache writes with compute (experimental).
"VLLM_KV_COMPRESSION_ASYNC_WRITEBACK":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_ASYNC_WRITEBACK", "0"))),
# If set, free unused tail KV cache blocks after prompt compaction.
"VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS", "0"))),
# If set, vLLM will use optimized MLA attention optimizations.
"VLLM_USE_TRITON_OPT_MLA":
......
......@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
import threading
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
import numpy as np
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
......@@ -32,6 +34,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches)
......@@ -42,10 +45,31 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
_DISABLE_SNAPKV_TRITON: bool = False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
# Dedicated stream for KV compaction writeback (reshape_and_cache_*). This is
# used to overlap cache write bandwidth with compute on the default stream.
_KV_COMPRESSION_STORE_STREAMS: dict[int, torch.cuda.Stream] = {}
_KV_COMPRESSION_STORE_STREAMS_LOCK = threading.Lock()
def _get_kv_compression_store_stream(device: torch.device) -> torch.cuda.Stream:
if device.type != "cuda":
raise ValueError("KV compression STORE_STREAM requires a CUDA device.")
device_index = device.index
if device_index is None:
device_index = torch.cuda.current_device()
with _KV_COMPRESSION_STORE_STREAMS_LOCK:
stream = _KV_COMPRESSION_STORE_STREAMS.get(device_index)
if stream is None:
with torch.cuda.device(device_index):
stream = torch.cuda.Stream()
_KV_COMPRESSION_STORE_STREAMS[device_index] = stream
return stream
class FlashAttentionBackend(AttentionBackend):
......@@ -161,6 +185,11 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
# KV compression metadata for token-shared selection.
kv_compression_must_keep: Optional[torch.Tensor] = None
kv_compression_topk_budget: Optional[torch.Tensor] = None
# CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max: Optional[int] = None
# Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None
......@@ -268,6 +297,22 @@ class FlashAttentionMetadataBuilder(
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
kv_compression_must_keep = None
kv_compression_topk_budget = None
kv_compression_topk_budget_max: Optional[int] = None
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.kv_compression_needs_compaction):
kv_compression_must_keep = self.runner.kv_compression_must_keep[:
num_actual_tokens]
kv_compression_topk_budget = self.runner.kv_compression_topk_budget[:
num_reqs]
# Avoid device->host sync by reading from the CPU staging buffer.
if num_reqs > 0:
kv_compression_topk_budget_max = int(
self.runner.kv_compression_topk_budget_np[:num_reqs].max())
else:
kv_compression_topk_budget_max = 0
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
# For the AOT scheduler we need the sliding window value to be
......@@ -426,6 +471,9 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
kv_compression_must_keep=kv_compression_must_keep,
kv_compression_topk_budget=kv_compression_topk_budget,
kv_compression_topk_budget_max=kv_compression_topk_budget_max,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
......@@ -495,6 +543,10 @@ class FlashAttentionImpl(AttentionImpl):
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
# KV compression async compaction writeback state.
self._kv_compression_writeback_pending: bool = False
self._kv_compression_writeback_event: Optional[torch.cuda.Event] = None
def forward(
self,
layer: torch.nn.Module,
......@@ -543,8 +595,20 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
cache_block_size = key_cache.shape[-3]
else:
key_cache, value_cache = kv_cache
cache_block_size = key_cache.shape[-2]
# Ensure any async compaction writeback from the previous step for this
# layer has completed before we read/write this layer's KV cache again.
# This is required because the scheduler advances `num_kv_tokens`
# assuming compaction has taken effect.
if self._kv_compression_writeback_pending and query.is_cuda:
event = self._kv_compression_writeback_event
if event is not None:
torch.cuda.current_stream(device=query.device).wait_event(event)
self._kv_compression_writeback_pending = False
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
......@@ -675,6 +739,156 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True,
)
# Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step.
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.kv_sharing_target_layer_name is None):
dst = None
if (attn_metadata.kv_compression_must_keep is not None
and attn_metadata.kv_compression_topk_budget
is not None):
forward_context = get_forward_context()
per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER
if per_layer_topk:
layer_name = getattr(layer, "layer_name", None)
if layer_name is None:
layer_name = str(id(layer))
dst_by_layer = getattr(
forward_context, "_kv_compression_compact_slots_by_layer",
None)
if dst_by_layer is None:
dst_by_layer = {}
setattr(
forward_context,
"_kv_compression_compact_slots_by_layer",
dst_by_layer,
)
dst = dst_by_layer.get(layer_name)
else:
dst = getattr(forward_context,
"_kv_compression_compact_slots", None)
if dst is None:
topk_budget = attn_metadata.kv_compression_topk_budget
token_scores: Optional[torch.Tensor] = None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max = int(
attn_metadata.kv_compression_topk_budget_max or 0)
if topk_budget_max > 0:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
w = torch.where(
topk_budget > 0,
torch.full_like(topk_budget, window),
torch.zeros_like(topk_budget),
)
token_scores = _snapkv_like_token_scores(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
query_start_loc=attn_metadata.query_start_loc,
window=w,
sm_scale=self.scale,
)
dst = _topk_kv_compact_slot_mapping(
token_scores=token_scores,
must_keep=attn_metadata.kv_compression_must_keep,
topk_budget=topk_budget,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
block_table=attn_metadata.block_table,
block_size=cache_block_size,
max_query_len=attn_metadata.max_query_len,
topk_budget_max=topk_budget_max,
)
if per_layer_topk:
dst_by_layer[layer_name] = dst
else:
setattr(forward_context,
"_kv_compression_compact_slots", dst)
if dst is not None:
src = attn_metadata.slot_mapping
rewrite_mask = (dst >= 0) & (dst != src)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite = torch.where(rewrite_mask, dst, -1)
async_writeback = (
envs.VLLM_KV_COMPRESSION_ASYNC_WRITEBACK
and query.is_cuda
and (not hasattr(torch.cuda, "is_current_stream_capturing")
or not torch.cuda.is_current_stream_capturing()))
def _writeback(dst_mapping: torch.Tensor) -> None:
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if (envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
and key.dtype == value.dtype
and key.dtype == torch.float16):
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
from vllm.attention.utils.fa_utils import (
reshape_and_cache_cuda)
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if async_writeback:
store_stream = _get_kv_compression_store_stream(
query.device)
store_stream.wait_stream(
torch.cuda.current_stream(device=query.device))
if self._kv_compression_writeback_event is None:
self._kv_compression_writeback_event = torch.cuda.Event(
enable_timing=False)
# Ensure temporaries live until STORE_STREAM finishes.
key.record_stream(store_stream)
value.record_stream(store_stream)
dst_rewrite.record_stream(store_stream)
with torch.cuda.stream(store_stream):
_writeback(dst_rewrite)
self._kv_compression_writeback_event.record()
self._kv_compression_writeback_pending = True
else:
_writeback(dst_rewrite)
return output
assert not use_local_attn, (
......@@ -733,6 +947,228 @@ class FlashAttentionImpl(AttentionImpl):
return output
def _snapkv_like_token_scores(
*,
query: torch.Tensor, # [T, Hq, D]
key: torch.Tensor, # [T, Hkv, D]
query_start_loc: torch.Tensor, # [B+1]
window: Union[int, torch.Tensor],
sm_scale: float,
) -> torch.Tensor:
"""Compute token-shared SnapKV-like scores for a packed varlen batch.
Scores are computed as the attention mass from the last `window` query
tokens to the earlier keys within the same scheduled segment (per request),
summed across KV heads.
Prefers a Triton implementation when available; falls back to a (slower)
PyTorch reference implementation otherwise.
"""
global _DISABLE_SNAPKV_TRITON
device = query.device
T, Hq, D = query.shape
Hkv = key.shape[1]
if Hq % Hkv != 0:
raise ValueError("Query heads must be a multiple of KV heads.")
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if (HAS_TRITON and not _DISABLE_SNAPKV_TRITON and device.type == "cuda"
and (not current_platform.is_rocm()
or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
and query.stride(-1) == 1 and key.stride(-1) == 1):
try:
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
w = int(window) if isinstance(window, int) else window
scores_per_head = query_aware_key_scores(
q=query,
k=key,
cu_seqlens_q=query_start_loc,
cu_seqlens_k=query_start_loc,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(token_scores)
except Exception as e:
_DISABLE_SNAPKV_TRITON = True
logger.warning(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s", e)
group = Hq // Hkv
# Read boundaries on host (small tensor).
qsl = query_start_loc.tolist()
B = len(qsl) - 1
wsl = None
if not isinstance(window, int):
if int(window.numel()) != B:
raise ValueError("window must be a scalar int or have shape [B].")
wsl = window.to(device="cpu", dtype=torch.int64).tolist()
scores = torch.zeros((T, ), device=device, dtype=torch.float32)
for b in range(B):
s = int(qsl[b])
e = int(qsl[b + 1])
L = e - s
if L <= 0:
continue
win_b = int(window) if wsl is None else int(wsl[b])
if win_b <= 0:
continue
win = min(win_b, L)
k_eff_end = L - win
if k_eff_end <= 0:
continue
q_win = query[e - win:e] # [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2) # [win, Hkv, D]
k_eff = key[s:s + k_eff_end] # [k_eff_end, Hkv, D]
qh = q_win.permute(1, 0, 2).to(torch.float32) # [Hkv, win, D]
kh = k_eff.permute(1, 0, 2).to(torch.float32) # [Hkv, k_eff_end, D]
logits = torch.matmul(qh, kh.transpose(1, 2)) * sm_scale # [Hkv, win, K]
probs = torch.softmax(logits, dim=-1)
# Sum over (heads, window queries) -> per-key token score.
scores[s:s + k_eff_end] = probs.sum(dim=1).sum(dim=0)
# Aggregate across tensor-parallel ranks so every rank selects the same
# token indices.
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(scores)
def _topk_kv_compact_slot_mapping(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
query_start_loc: torch.Tensor, # [B+1]
seq_lens: torch.Tensor, # [B] int32
block_table: torch.Tensor, # [B, max_blocks]
block_size: int,
max_query_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0 or B == 0:
return dst_slots
# Per-request segment boundaries in the packed [T] layout.
# NOTE: `query_start_loc` is already sliced to [B+1] by the model runner.
starts = query_start_loc[:B].to(torch.long)
ends = query_start_loc[1:B + 1].to(torch.long)
lengths = ends - starts # [B]
if lengths.numel() == 0:
return dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max = int(max_query_len) if max_query_len is not None else int(
lengths.max().item())
if L_max <= 0:
return dst_slots
# Map each token to its (request, offset-within-request) coordinate.
token_idx = torch.arange(T, device=device, dtype=torch.long)
# For monotonic `ends` (cu_seqlens), this returns the request id for each
# token in the packed layout.
# Use right=True so that idx==ends[b] maps to the *next* request (b+1),
# i.e., request segments are [start, end) in the packed layout.
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T] in [0, L_b)
# Clamp the per-request top-k budget to the number of candidate tokens
# (excluding must_keep).
must_keep_counts = torch.zeros(B, device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# Prefer an upper bound from CPU (piecewise graph), to avoid sync.
if topk_budget_max is not None:
k_max = min(int(topk_budget_max), L_max)
else:
k_max = int(k_eff.max().item())
# Build a padded [B, L_max] score matrix for a single batched Top-K call.
# Must-keep and padding positions are set to -inf to avoid selection.
keep_mask = must_keep.clone()
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(dtype=torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
# Select only the first k_eff[b] entries for each request b.
col_mask = torch.arange(k_max, device=device).unsqueeze(
0) < k_eff.unsqueeze(1) # [B, k_max]
# Avoid host-side synchronization from dynamic indexing. Instead, mark
# selected tokens via a fixed-size scatter-add.
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B, k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
# Compute segment-local ranks (0..kept-1) for kept tokens, preserving token
# order within each request, without dynamic indexing (graph-friendly).
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1.to(torch.long))
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
dst_slots = torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
return dst_slots
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Optional, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
if HAS_TRITON:
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk},
num_warps=num_warps,
num_stages=num_stages,
) for bq in [32, 64] for bk in [32, 64] for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)
s = tl.dot(q_rows, k_blk.T) * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton.jit
def _lse_and_store_logits_kernel_rocm_safe(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""ROCm-safe variant of `_lse_and_store_logits_kernel`.
On some ROCm + Triton (HIP) stacks we have observed memory corruption
from the tl.dot-based implementation. This variant avoids `tl.dot` and
instead computes dot-products via explicit outer-product accumulation.
"""
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
# Accumulate s = Q @ K^T in fp32 via outer products.
s = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
for ds in tl.static_range(0, D, BLOCK_D):
offs_d = ds + tl.arange(0, BLOCK_D)
dmask = offs_d < D
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_chunk = tl.load(
q_ptrs,
mask=row_mask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BQ, BD]
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_chunk = tl.load(
k_ptrs,
mask=kmask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BK, BD]
s += tl.sum(q_chunk[:, None, :] * k_chunk[None, :, :], axis=2)
s = s * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
m,
mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
S,
mask=row_mask)
@triton.autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64] for bk in [32, 64, 128]
],
key=["HK", "HQ"],
)
@triton.jit
def _scores_from_logits_kernel(
cu_k,
w_b,
in_m,
in_S,
LOGITS,
OUT,
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
DO_POOL: tl.constexpr,
KPOOL: tl.constexpr,
PROTECT_LAST: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
scores = tl.zeros([BLOCK_K], dtype=tl.float32)
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = (in_m + b * STRIDE_M_B + hk * STRIDE_M_H +
row0 * STRIDE_M_R)
S_ptr = (in_S + b * STRIDE_S_B + hk * STRIDE_S_H +
row0 * STRIDE_S_R)
m = tl.load(m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"))
S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
mask=rmask,
other=0.0)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] *
STRIDE_LG_R)
s_T = tl.load(log_ptrs,
mask=kmask[:, None] & rmask[None, :],
other=-float("inf"))
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
scores += tl.sum(probs_T, 1)
if DO_POOL and (KPOOL > 1):
i = tl.arange(0, BLOCK_K)[:, None]
j = tl.arange(0, BLOCK_K)[None, :]
band = (j <= i) & ((i - j) < KPOOL)
band = band & kmask[None, :]
sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
denom = tl.sum(band, 1).to(tl.float32)
denom = tl.where(denom > 0, denom, 1.0)
scores = sums / denom
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, scores, mask=kmask)
if PROTECT_LAST:
pad_beg = k_eff_end
pad_end = k_end
if pad_end > pad_beg:
for ks in tl.range(pad_beg, pad_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < pad_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs,
tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
mask=kmask)
@triton.autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT,
cu_k,
w_b,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr,
EPS: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1] int32
cu_seqlens_k: torch.Tensor, # [B+1] int32
w: Union[int, torch.Tensor], # [B] int32 or scalar
sm_scale: Optional[float] = None,
*,
pool: bool = True,
kpool: int = 5,
protect_last: bool = True,
normalize: bool = False,
) -> torch.Tensor:
"""SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32."""
if not HAS_TRITON:
raise RuntimeError("Triton is not available.")
if q.device.type != "cuda" or k.device.type != "cuda":
raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
if q.ndim != 3 or k.ndim != 3:
raise ValueError("q and k must be 3D tensors.")
if q.stride(-1) != 1 or k.stride(-1) != 1:
raise ValueError("Last dim must be contiguous for Triton SnapKV.")
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
if D != Dk:
raise ValueError("q and k must have the same head size.")
if (Hq % Hk) != 0:
raise ValueError("Hq must be a multiple of Hk.")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = int(cu_seqlens_q.numel() - 1)
if B != int(cu_seqlens_k.numel() - 1):
raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
G = Hq // Hk
if isinstance(w, int):
max_w = int(w)
w = torch.full((B, ),
fill_value=max_w,
device=device,
dtype=torch.int32)
else:
if w.numel() != B:
raise ValueError("w must have shape [B].")
w = w.to(device=device, dtype=torch.int32)
max_w = int(w.max().item())
rows_max = max_w * G
if rows_max <= 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
if kpool < 1:
raise ValueError("kpool must be >= 1.")
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
S_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
logits_buf = torch.empty((N_k, Hk, rows_max),
dtype=torch.float32,
device=device)
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(meta):
return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])
cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32)
cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32)
# NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
# correctness issues (silent memory corruption) observed with the tl.dot
# implementation on some stacks.
is_rocm = getattr(torch.version, "hip", None) is not None
if is_rocm:
_lse_and_store_logits_kernel_rocm_safe[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
BLOCK_Q=32,
BLOCK_K=32,
BLOCK_D=16,
num_warps=4,
num_stages=1,
)
else:
_lse_and_store_logits_kernel[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=rows_max,
)
_scores_from_logits_kernel[(B, Hk)](
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=G,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=pool,
KPOOL=kpool,
PROTECT_LAST=protect_last,
)
if normalize:
_zscore_per_batch_epilogue[(B, )](
out,
cu_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
return out
......@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
Returns True if any blocks were freed.
"""
truncated = False
for manager in self.single_type_managers:
truncated = manager.truncate_to_num_tokens(request_id,
num_tokens) or truncated
return truncated
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
......
......@@ -7,6 +7,8 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
......@@ -251,9 +253,17 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens)
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
if envs.VLLM_ENABLE_KV_COMPRESSION and not current_platform.is_tpu():
# KV compression decouples logical positions from KV cache
# positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot = min(
request.num_kv_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
else:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
......@@ -385,6 +395,14 @@ class KVCacheManager:
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort operation that may free blocks back to the pool.
Returns True if any blocks were freed.
"""
return self.coordinator.truncate_to_num_tokens(request_id, num_tokens)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
......
......@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest]
@classmethod
......@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request,
)
......@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property
def num_reqs(self) -> int:
......@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
num_kv_tokens=[],
)
......
......@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs
logger = init_logger(__name__)
......@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# KV compression is a GPU-only feature in this fork; ignore it on TPU.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and not current_platform.is_tpu())
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_tpu():
logger.warning_once(
"KV compression is not supported on TPU; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
......@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -321,8 +378,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
......@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -941,8 +1026,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
......@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -1076,8 +1166,51 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
......@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
......@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
num_kv_tokens.append(req.num_kv_tokens)
resumed_from_preemption.append(
(req in resumed_reqs) or (req_id in force_replace_block_ids))
return CachedRequestData(
req_ids=req_ids,
......@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
)
def _try_schedule_encoder_inputs(
......@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
......
......@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort optimization hook. Subclasses may override this
to free no-longer-needed blocks (e.g., after KV compaction). The default
implementation is a no-op.
"""
return False
@abstractmethod
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
......@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention.
pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
num_tokens = max(int(num_tokens), 0)
blocks = self.req_to_blocks.get(request_id)
if not blocks:
return False
num_required_blocks = cdiv(num_tokens, self.block_size)
if num_required_blocks >= len(blocks):
return False
removed_blocks = blocks[num_required_blocks:]
del blocks[num_required_blocks:]
self.block_pool.free_blocks(reversed(removed_blocks))
if request_id in self.num_cached_block:
self.num_cached_block[request_id] = min(
self.num_cached_block[request_id], len(blocks))
return True
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
start = _clamp_int(start_pos, 0, prompt_len)
end = _clamp_int(end_pos, 0, prompt_len)
if end <= start:
return 0
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
keep_prefix = _intersection_len(start, end, 0, prefix_len)
keep_suffix = _intersection_len(start, end, suffix_start, prompt_len)
overlap = _intersection_len(start, end, suffix_start, prefix_len)
kept = keep_prefix + keep_suffix - overlap
if keep_last_token:
last = prompt_len - 1
if start <= last < end:
already_kept = (last < prefix_len) or (last >= suffix_start)
if not already_kept:
kept += 1
return kept
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
return _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
total_keep = _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
if total_keep <= 0:
return 0
cand_upto_start = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=start_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
cand_upto_end = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
step_total = max(0, cand_upto_end - cand_upto_start)
if step_total == 0:
return 0
bud_upto_start = (total_keep * cand_upto_start) // total
bud_upto_end = (total_keep * cand_upto_end) // total
step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total)
......@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
......
......@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
......
......@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int]
spec_token_ids: list[int] = None
......@@ -114,6 +115,13 @@ class InputBatch:
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
......@@ -348,6 +356,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
......@@ -504,6 +513,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
......@@ -602,6 +613,8 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
......
......@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.kv_compression.budget import compute_topk_budget_step
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
......@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.attention_chunk_size = model_config.attention_chunk_size
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_KV_COMPRESSION:
# KV compression changes the effective KV sequence layout and
# invalidates cascade attention assumptions (common-prefix blocks).
self.cascade_attn_enabled = False
# Whether the current step needs KV compaction work (score/topk/dst).
# This is set per-step in `_prepare_inputs`.
self.kv_compression_needs_compaction: bool = False
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
......@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
# KV positions are decoupled from logical positions when KV compression
# is enabled. We keep a separate buffer to avoid recomputing or
# overwriting `positions_np` (used for RoPE / input token lookup).
self.kv_positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.kv_positions_np = self.kv_positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
......@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
self.kv_compression_must_keep_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_must_keep_np = self.kv_compression_must_keep_cpu.numpy()
self.kv_compression_must_keep = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device=self.device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
self.kv_compression_topk_budget_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_topk_budget_np = self.kv_compression_topk_budget_cpu.numpy()
self.kv_compression_topk_budget = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
......@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
......@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
......@@ -545,7 +592,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
......@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally
# kept and there is no Top-K budget, KV compaction is a no-op and we
# can skip score/topk/dst entirely in the attention backend.
self.kv_compression_needs_compaction = (not must_keep_np.all()) or (
topk_budget_np > 0).any()
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
......@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
......@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
block_offsets = slot_positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
......@@ -707,9 +832,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
if use_kv_compression:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
......@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
if use_kv_compression:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
......@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
if envs.VLLM_ENABLE_KV_COMPRESSION and self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
......@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if (envs.VLLM_ENABLE_KV_COMPRESSION
and attn_backend_i.get_name() != "FLASH_ATTN_VLLM_V1"):
raise ValueError(
"KV compression currently requires "
"VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.")
elif isinstance(kv_cache_spec, MambaSpec):
if envs.VLLM_ENABLE_KV_COMPRESSION:
raise ValueError(
"KV compression is currently only supported for "
"Transformer attention layers.")
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
......@@ -3689,4 +3841,4 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if envs.VLLM_USE_ZERO_MTP:
GPUModelRunner=GPUModelRunnerMTP
else:
GPUModelRunner=GPUModelRunnerBase
\ No newline at end of file
GPUModelRunner=GPUModelRunnerBase
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment