Commit 2676ad00 authored by laibao's avatar laibao
Browse files

v1: add SnapKV Triton KV compression

Introduce v1 KV compression modules (budget + SnapKV Triton kernel) and integrate with scheduler/cache managers.
parent 155c8a13
...@@ -141,6 +141,34 @@ if TYPE_CHECKING: ...@@ -141,6 +141,34 @@ if TYPE_CHECKING:
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
# KV compression (token-shared) for v1 paged attention.
# When enabled, vLLM decouples logical positions from KV cache positions
# and keeps only a subset of prompt tokens in KV cache during prefill.
VLLM_ENABLE_KV_COMPRESSION: bool = False
# KV compression policy for selecting which prompt KV entries to retain.
# Currently only "topk" is supported.
VLLM_KV_COMPRESSION_POLICY: str = "topk"
# Target prompt KV budget for token-shared compression.
# If PROMPT_BUDGET >= 0, it takes precedence over PROMPT_RATIO.
# The budget/ratio applies to non-protected prompt tokens only.
VLLM_KV_COMPRESSION_PROMPT_RATIO: float = 1.0
VLLM_KV_COMPRESSION_PROMPT_BUDGET: int = -1
VLLM_KV_COMPRESSION_PROTECTED_PREFIX: int = 0
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX: int = 0
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN: bool = True
# SnapKV-like scoring wi这个ndow used by the "topk" policy.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW: int = 32
# Use Triton SnapKV scoring on ROCm (experimental). Set to 0 to force the
# PyTorch reference implementation.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM: bool = True
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing a single selection across all layers in a forward pass.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER: bool = False
# Run KV compaction writeback (reshape_and_cache_*) on a separate CUDA
# stream to overlap with compute (experimental).
VLLM_KV_COMPRESSION_ASYNC_WRITEBACK: bool = False
# Free unused tail KV cache blocks after prompt compaction (experimental).
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
# add envs # add envs
...@@ -1055,6 +1083,50 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1055,6 +1083,50 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# Enable token-shared KV compression for v1 paged attention (experimental).
# This feature currently targets long-prompt prefill memory reduction.
"VLLM_ENABLE_KV_COMPRESSION":
lambda: bool(int(os.getenv("VLLM_ENABLE_KV_COMPRESSION", "0"))),
# KV compression policy ("topk").
"VLLM_KV_COMPRESSION_POLICY":
lambda: os.getenv("VLLM_KV_COMPRESSION_POLICY", "topk").lower(),
# Target fraction of non-protected prompt tokens to keep in KV cache.
"VLLM_KV_COMPRESSION_PROMPT_RATIO":
lambda: float(os.getenv("VLLM_KV_COMPRESSION_PROMPT_RATIO", "1.0")),
# Target number of non-protected prompt tokens to keep in KV cache.
# If >= 0, this takes precedence over VLLM_KV_COMPRESSION_PROMPT_RATIO.
"VLLM_KV_COMPRESSION_PROMPT_BUDGET":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROMPT_BUDGET", "-1")),
# Always keep the first N prompt tokens in KV cache (e.g. BOS/system).
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_PREFIX", "0")),
# Always keep the last N prompt tokens in KV cache.
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_SUFFIX", "0")),
# Always keep the last prompt token (prompt_len - 1) when it is scheduled.
"VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN", "1"))),
# SnapKV-like scoring window size for the "topk" policy.
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_WINDOW", "32")),
# Enable Triton SnapKV scoring on ROCm (experimental).
"VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM", "1"))),
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing one selection across layers in a forward pass.
"VLLM_KV_COMPRESSION_TOPK_PER_LAYER":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_TOPK_PER_LAYER", "0"))),
# If set, run KV compaction writeback on a separate CUDA stream to overlap
# cache writes with compute (experimental).
"VLLM_KV_COMPRESSION_ASYNC_WRITEBACK":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_ASYNC_WRITEBACK", "0"))),
# If set, free unused tail KV cache blocks after prompt compaction.
"VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS", "0"))),
# If set, vLLM will use optimized MLA attention optimizations. # If set, vLLM will use optimized MLA attention optimizations.
"VLLM_USE_TRITON_OPT_MLA": "VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))), lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple import threading
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
...@@ -32,6 +34,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -32,6 +34,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches) make_local_attention_virtual_batches)
...@@ -42,10 +45,31 @@ if TYPE_CHECKING: ...@@ -42,10 +45,31 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
_DISABLE_SNAPKV_TRITON: bool = False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed. # NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
# Dedicated stream for KV compaction writeback (reshape_and_cache_*). This is
# used to overlap cache write bandwidth with compute on the default stream.
_KV_COMPRESSION_STORE_STREAMS: dict[int, torch.cuda.Stream] = {}
_KV_COMPRESSION_STORE_STREAMS_LOCK = threading.Lock()
def _get_kv_compression_store_stream(device: torch.device) -> torch.cuda.Stream:
if device.type != "cuda":
raise ValueError("KV compression STORE_STREAM requires a CUDA device.")
device_index = device.index
if device_index is None:
device_index = torch.cuda.current_device()
with _KV_COMPRESSION_STORE_STREAMS_LOCK:
stream = _KV_COMPRESSION_STORE_STREAMS.get(device_index)
if stream is None:
with torch.cuda.device(device_index):
stream = torch.cuda.Stream()
_KV_COMPRESSION_STORE_STREAMS[device_index] = stream
return stream
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -161,6 +185,11 @@ class FlashAttentionMetadata: ...@@ -161,6 +185,11 @@ class FlashAttentionMetadata:
cu_prefix_query_lens: Optional[torch.Tensor] cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor] prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor]
# KV compression metadata for token-shared selection.
kv_compression_must_keep: Optional[torch.Tensor] = None
kv_compression_topk_budget: Optional[torch.Tensor] = None
# CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max: Optional[int] = None
# Optional aot scheduling # Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None scheduler_metadata: Optional[torch.Tensor] = None
...@@ -268,6 +297,22 @@ class FlashAttentionMetadataBuilder( ...@@ -268,6 +297,22 @@ class FlashAttentionMetadataBuilder(
slot_mapping = block_table.slot_mapping[:num_actual_tokens] slot_mapping = block_table.slot_mapping[:num_actual_tokens]
kv_compression_must_keep = None
kv_compression_topk_budget = None
kv_compression_topk_budget_max: Optional[int] = None
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.kv_compression_needs_compaction):
kv_compression_must_keep = self.runner.kv_compression_must_keep[:
num_actual_tokens]
kv_compression_topk_budget = self.runner.kv_compression_topk_budget[:
num_reqs]
# Avoid device->host sync by reading from the CPU staging buffer.
if num_reqs > 0:
kv_compression_topk_budget_max = int(
self.runner.kv_compression_topk_budget_np[:num_reqs].max())
else:
kv_compression_topk_budget_max = 0
if self.aot_sliding_window is None: if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1) self.aot_sliding_window = (-1, -1)
# For the AOT scheduler we need the sliding window value to be # For the AOT scheduler we need the sliding window value to be
...@@ -426,6 +471,9 @@ class FlashAttentionMetadataBuilder( ...@@ -426,6 +471,9 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens, cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens, prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
kv_compression_must_keep=kv_compression_must_keep,
kv_compression_topk_budget=kv_compression_topk_budget,
kv_compression_topk_budget_max=kv_compression_topk_budget_max,
local_attn_metadata=local_attn_metadata, local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits, max_num_splits=max_num_splits,
...@@ -495,6 +543,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -495,6 +543,10 @@ class FlashAttentionImpl(AttentionImpl):
raise NotImplementedError( raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.") "FlashAttention does not support fp8 kv-cache on this device.")
# KV compression async compaction writeback state.
self._kv_compression_writeback_pending: bool = False
self._kv_compression_writeback_event: Optional[torch.cuda.Event] = None
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -543,8 +595,20 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -543,8 +595,20 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
if not current_platform.is_rocm(): if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
cache_block_size = key_cache.shape[-3]
else: else:
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
cache_block_size = key_cache.shape[-2]
# Ensure any async compaction writeback from the previous step for this
# layer has completed before we read/write this layer's KV cache again.
# This is required because the scheduler advances `num_kv_tokens`
# assuming compaction has taken effect.
if self._kv_compression_writeback_pending and query.is_cuda:
event = self._kv_compression_writeback_event
if event is not None:
torch.cuda.current_stream(device=query.device).wait_event(event)
self._kv_compression_writeback_pending = False
if self.kv_sharing_target_layer_name is None: if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
...@@ -675,6 +739,156 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -675,6 +739,156 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True, is_prefix_cache=True,
) )
# Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step.
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.kv_sharing_target_layer_name is None):
dst = None
if (attn_metadata.kv_compression_must_keep is not None
and attn_metadata.kv_compression_topk_budget
is not None):
forward_context = get_forward_context()
per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER
if per_layer_topk:
layer_name = getattr(layer, "layer_name", None)
if layer_name is None:
layer_name = str(id(layer))
dst_by_layer = getattr(
forward_context, "_kv_compression_compact_slots_by_layer",
None)
if dst_by_layer is None:
dst_by_layer = {}
setattr(
forward_context,
"_kv_compression_compact_slots_by_layer",
dst_by_layer,
)
dst = dst_by_layer.get(layer_name)
else:
dst = getattr(forward_context,
"_kv_compression_compact_slots", None)
if dst is None:
topk_budget = attn_metadata.kv_compression_topk_budget
token_scores: Optional[torch.Tensor] = None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max = int(
attn_metadata.kv_compression_topk_budget_max or 0)
if topk_budget_max > 0:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
w = torch.where(
topk_budget > 0,
torch.full_like(topk_budget, window),
torch.zeros_like(topk_budget),
)
token_scores = _snapkv_like_token_scores(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
query_start_loc=attn_metadata.query_start_loc,
window=w,
sm_scale=self.scale,
)
dst = _topk_kv_compact_slot_mapping(
token_scores=token_scores,
must_keep=attn_metadata.kv_compression_must_keep,
topk_budget=topk_budget,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
block_table=attn_metadata.block_table,
block_size=cache_block_size,
max_query_len=attn_metadata.max_query_len,
topk_budget_max=topk_budget_max,
)
if per_layer_topk:
dst_by_layer[layer_name] = dst
else:
setattr(forward_context,
"_kv_compression_compact_slots", dst)
if dst is not None:
src = attn_metadata.slot_mapping
rewrite_mask = (dst >= 0) & (dst != src)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite = torch.where(rewrite_mask, dst, -1)
async_writeback = (
envs.VLLM_KV_COMPRESSION_ASYNC_WRITEBACK
and query.is_cuda
and (not hasattr(torch.cuda, "is_current_stream_capturing")
or not torch.cuda.is_current_stream_capturing()))
def _writeback(dst_mapping: torch.Tensor) -> None:
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if (envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
and key.dtype == value.dtype
and key.dtype == torch.float16):
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
from vllm.attention.utils.fa_utils import (
reshape_and_cache_cuda)
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if async_writeback:
store_stream = _get_kv_compression_store_stream(
query.device)
store_stream.wait_stream(
torch.cuda.current_stream(device=query.device))
if self._kv_compression_writeback_event is None:
self._kv_compression_writeback_event = torch.cuda.Event(
enable_timing=False)
# Ensure temporaries live until STORE_STREAM finishes.
key.record_stream(store_stream)
value.record_stream(store_stream)
dst_rewrite.record_stream(store_stream)
with torch.cuda.stream(store_stream):
_writeback(dst_rewrite)
self._kv_compression_writeback_event.record()
self._kv_compression_writeback_pending = True
else:
_writeback(dst_rewrite)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
...@@ -733,6 +947,228 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -733,6 +947,228 @@ class FlashAttentionImpl(AttentionImpl):
return output return output
def _snapkv_like_token_scores(
*,
query: torch.Tensor, # [T, Hq, D]
key: torch.Tensor, # [T, Hkv, D]
query_start_loc: torch.Tensor, # [B+1]
window: Union[int, torch.Tensor],
sm_scale: float,
) -> torch.Tensor:
"""Compute token-shared SnapKV-like scores for a packed varlen batch.
Scores are computed as the attention mass from the last `window` query
tokens to the earlier keys within the same scheduled segment (per request),
summed across KV heads.
Prefers a Triton implementation when available; falls back to a (slower)
PyTorch reference implementation otherwise.
"""
global _DISABLE_SNAPKV_TRITON
device = query.device
T, Hq, D = query.shape
Hkv = key.shape[1]
if Hq % Hkv != 0:
raise ValueError("Query heads must be a multiple of KV heads.")
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if (HAS_TRITON and not _DISABLE_SNAPKV_TRITON and device.type == "cuda"
and (not current_platform.is_rocm()
or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
and query.stride(-1) == 1 and key.stride(-1) == 1):
try:
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
w = int(window) if isinstance(window, int) else window
scores_per_head = query_aware_key_scores(
q=query,
k=key,
cu_seqlens_q=query_start_loc,
cu_seqlens_k=query_start_loc,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(token_scores)
except Exception as e:
_DISABLE_SNAPKV_TRITON = True
logger.warning(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s", e)
group = Hq // Hkv
# Read boundaries on host (small tensor).
qsl = query_start_loc.tolist()
B = len(qsl) - 1
wsl = None
if not isinstance(window, int):
if int(window.numel()) != B:
raise ValueError("window must be a scalar int or have shape [B].")
wsl = window.to(device="cpu", dtype=torch.int64).tolist()
scores = torch.zeros((T, ), device=device, dtype=torch.float32)
for b in range(B):
s = int(qsl[b])
e = int(qsl[b + 1])
L = e - s
if L <= 0:
continue
win_b = int(window) if wsl is None else int(wsl[b])
if win_b <= 0:
continue
win = min(win_b, L)
k_eff_end = L - win
if k_eff_end <= 0:
continue
q_win = query[e - win:e] # [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2) # [win, Hkv, D]
k_eff = key[s:s + k_eff_end] # [k_eff_end, Hkv, D]
qh = q_win.permute(1, 0, 2).to(torch.float32) # [Hkv, win, D]
kh = k_eff.permute(1, 0, 2).to(torch.float32) # [Hkv, k_eff_end, D]
logits = torch.matmul(qh, kh.transpose(1, 2)) * sm_scale # [Hkv, win, K]
probs = torch.softmax(logits, dim=-1)
# Sum over (heads, window queries) -> per-key token score.
scores[s:s + k_eff_end] = probs.sum(dim=1).sum(dim=0)
# Aggregate across tensor-parallel ranks so every rank selects the same
# token indices.
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(scores)
def _topk_kv_compact_slot_mapping(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
query_start_loc: torch.Tensor, # [B+1]
seq_lens: torch.Tensor, # [B] int32
block_table: torch.Tensor, # [B, max_blocks]
block_size: int,
max_query_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0 or B == 0:
return dst_slots
# Per-request segment boundaries in the packed [T] layout.
# NOTE: `query_start_loc` is already sliced to [B+1] by the model runner.
starts = query_start_loc[:B].to(torch.long)
ends = query_start_loc[1:B + 1].to(torch.long)
lengths = ends - starts # [B]
if lengths.numel() == 0:
return dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max = int(max_query_len) if max_query_len is not None else int(
lengths.max().item())
if L_max <= 0:
return dst_slots
# Map each token to its (request, offset-within-request) coordinate.
token_idx = torch.arange(T, device=device, dtype=torch.long)
# For monotonic `ends` (cu_seqlens), this returns the request id for each
# token in the packed layout.
# Use right=True so that idx==ends[b] maps to the *next* request (b+1),
# i.e., request segments are [start, end) in the packed layout.
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T] in [0, L_b)
# Clamp the per-request top-k budget to the number of candidate tokens
# (excluding must_keep).
must_keep_counts = torch.zeros(B, device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# Prefer an upper bound from CPU (piecewise graph), to avoid sync.
if topk_budget_max is not None:
k_max = min(int(topk_budget_max), L_max)
else:
k_max = int(k_eff.max().item())
# Build a padded [B, L_max] score matrix for a single batched Top-K call.
# Must-keep and padding positions are set to -inf to avoid selection.
keep_mask = must_keep.clone()
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(dtype=torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
# Select only the first k_eff[b] entries for each request b.
col_mask = torch.arange(k_max, device=device).unsqueeze(
0) < k_eff.unsqueeze(1) # [B, k_max]
# Avoid host-side synchronization from dynamic indexing. Instead, mark
# selected tokens via a fixed-size scatter-add.
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B, k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
# Compute segment-local ranks (0..kept-1) for kept tokens, preserving token
# order within each request, without dynamic indexing (graph-friendly).
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1.to(torch.long))
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
dst_slots = torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
return dst_slots
def use_cascade_attention( def use_cascade_attention(
common_prefix_len: int, common_prefix_len: int,
query_lens: np.ndarray, query_lens: np.ndarray,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Optional, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
if HAS_TRITON:
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk},
num_warps=num_warps,
num_stages=num_stages,
) for bq in [32, 64] for bk in [32, 64] for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)
s = tl.dot(q_rows, k_blk.T) * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton.jit
def _lse_and_store_logits_kernel_rocm_safe(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""ROCm-safe variant of `_lse_and_store_logits_kernel`.
On some ROCm + Triton (HIP) stacks we have observed memory corruption
from the tl.dot-based implementation. This variant avoids `tl.dot` and
instead computes dot-products via explicit outer-product accumulation.
"""
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
# Accumulate s = Q @ K^T in fp32 via outer products.
s = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
for ds in tl.static_range(0, D, BLOCK_D):
offs_d = ds + tl.arange(0, BLOCK_D)
dmask = offs_d < D
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_chunk = tl.load(
q_ptrs,
mask=row_mask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BQ, BD]
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_chunk = tl.load(
k_ptrs,
mask=kmask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BK, BD]
s += tl.sum(q_chunk[:, None, :] * k_chunk[None, :, :], axis=2)
s = s * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
m,
mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
S,
mask=row_mask)
@triton.autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64] for bk in [32, 64, 128]
],
key=["HK", "HQ"],
)
@triton.jit
def _scores_from_logits_kernel(
cu_k,
w_b,
in_m,
in_S,
LOGITS,
OUT,
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
DO_POOL: tl.constexpr,
KPOOL: tl.constexpr,
PROTECT_LAST: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
scores = tl.zeros([BLOCK_K], dtype=tl.float32)
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = (in_m + b * STRIDE_M_B + hk * STRIDE_M_H +
row0 * STRIDE_M_R)
S_ptr = (in_S + b * STRIDE_S_B + hk * STRIDE_S_H +
row0 * STRIDE_S_R)
m = tl.load(m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"))
S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
mask=rmask,
other=0.0)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] *
STRIDE_LG_R)
s_T = tl.load(log_ptrs,
mask=kmask[:, None] & rmask[None, :],
other=-float("inf"))
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
scores += tl.sum(probs_T, 1)
if DO_POOL and (KPOOL > 1):
i = tl.arange(0, BLOCK_K)[:, None]
j = tl.arange(0, BLOCK_K)[None, :]
band = (j <= i) & ((i - j) < KPOOL)
band = band & kmask[None, :]
sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
denom = tl.sum(band, 1).to(tl.float32)
denom = tl.where(denom > 0, denom, 1.0)
scores = sums / denom
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, scores, mask=kmask)
if PROTECT_LAST:
pad_beg = k_eff_end
pad_end = k_end
if pad_end > pad_beg:
for ks in tl.range(pad_beg, pad_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < pad_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs,
tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
mask=kmask)
@triton.autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT,
cu_k,
w_b,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr,
EPS: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1] int32
cu_seqlens_k: torch.Tensor, # [B+1] int32
w: Union[int, torch.Tensor], # [B] int32 or scalar
sm_scale: Optional[float] = None,
*,
pool: bool = True,
kpool: int = 5,
protect_last: bool = True,
normalize: bool = False,
) -> torch.Tensor:
"""SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32."""
if not HAS_TRITON:
raise RuntimeError("Triton is not available.")
if q.device.type != "cuda" or k.device.type != "cuda":
raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
if q.ndim != 3 or k.ndim != 3:
raise ValueError("q and k must be 3D tensors.")
if q.stride(-1) != 1 or k.stride(-1) != 1:
raise ValueError("Last dim must be contiguous for Triton SnapKV.")
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
if D != Dk:
raise ValueError("q and k must have the same head size.")
if (Hq % Hk) != 0:
raise ValueError("Hq must be a multiple of Hk.")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = int(cu_seqlens_q.numel() - 1)
if B != int(cu_seqlens_k.numel() - 1):
raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
G = Hq // Hk
if isinstance(w, int):
max_w = int(w)
w = torch.full((B, ),
fill_value=max_w,
device=device,
dtype=torch.int32)
else:
if w.numel() != B:
raise ValueError("w must have shape [B].")
w = w.to(device=device, dtype=torch.int32)
max_w = int(w.max().item())
rows_max = max_w * G
if rows_max <= 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
if kpool < 1:
raise ValueError("kpool must be >= 1.")
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
S_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
logits_buf = torch.empty((N_k, Hk, rows_max),
dtype=torch.float32,
device=device)
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(meta):
return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])
cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32)
cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32)
# NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
# correctness issues (silent memory corruption) observed with the tl.dot
# implementation on some stacks.
is_rocm = getattr(torch.version, "hip", None) is not None
if is_rocm:
_lse_and_store_logits_kernel_rocm_safe[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
BLOCK_Q=32,
BLOCK_K=32,
BLOCK_D=16,
num_warps=4,
num_stages=1,
)
else:
_lse_and_store_logits_kernel[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=rows_max,
)
_scores_from_logits_kernel[(B, Hk)](
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=G,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=pool,
KPOOL=kpool,
PROTECT_LAST=protect_last,
)
if normalize:
_zscore_per_batch_epilogue[(B, )](
out,
cu_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
return out
...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC): ...@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens) manager.remove_skipped_blocks(request_id, num_computed_tokens)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
Returns True if any blocks were freed.
"""
truncated = False
for manager in self.single_type_managers:
truncated = manager.truncate_to_num_tokens(request_id,
num_tokens) or truncated
return truncated
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
""" """
Get the blocks for the request. Get the blocks for the request.
......
...@@ -7,6 +7,8 @@ from typing import Optional ...@@ -7,6 +7,8 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import sha256 from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
...@@ -251,6 +253,14 @@ class KVCacheManager: ...@@ -251,6 +253,14 @@ class KVCacheManager:
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens) num_new_computed_tokens)
if envs.VLLM_ENABLE_KV_COMPRESSION and not current_platform.is_tpu():
# KV compression decouples logical positions from KV cache
# positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot = min(
request.num_kv_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
else:
num_tokens_need_slot = min( num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens, num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len) self.max_model_len)
...@@ -385,6 +395,14 @@ class KVCacheManager: ...@@ -385,6 +395,14 @@ class KVCacheManager:
return KVCacheBlocks( return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids() self.coordinator.get_blocks(request_id)).get_block_ids()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort operation that may free blocks back to the pool.
Returns True if any blocks were freed.
"""
return self.coordinator.truncate_to_num_tokens(request_id, num_tokens)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled."""
if self.enable_caching: if self.enable_caching:
......
...@@ -31,6 +31,7 @@ class NewRequestData: ...@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
@classmethod @classmethod
...@@ -49,6 +50,7 @@ class NewRequestData: ...@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
block_ids=block_ids, block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
) )
...@@ -62,6 +64,7 @@ class NewRequestData: ...@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -76,6 +79,7 @@ class NewRequestData: ...@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request}"
")") ")")
...@@ -93,6 +97,7 @@ class CachedRequestData: ...@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]] new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]] new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int] num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
...@@ -106,6 +111,7 @@ class CachedRequestData: ...@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[], new_token_ids=[],
new_block_ids=[], new_block_ids=[],
num_computed_tokens=[], num_computed_tokens=[],
num_kv_tokens=[],
) )
......
...@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy, ...@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs) EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs from vllm import envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface): ...@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla self.use_mla = vllm_config.model_config.use_mla
# KV compression is a GPU-only feature in this fork; ignore it on TPU.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and not current_platform.is_tpu())
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_tpu():
logger.warning_once(
"KV compression is not supported on TPU; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
...@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface): ...@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface): ...@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface): ...@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -321,6 +378,10 @@ class Scheduler(SchedulerInterface): ...@@ -321,6 +378,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = ( req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids()) new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
...@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface): ...@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface): ...@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface): ...@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging. # For logging.
scheduled_timestamp = time.monotonic() scheduled_timestamp = time.monotonic()
...@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface): ...@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface): ...@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens - num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0) request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface): ...@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats: if self.log_stats:
preempted_req.record_event( preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp) EngineCoreEventType.PREEMPTED, scheduled_timestamp)
...@@ -941,6 +1026,10 @@ class Scheduler(SchedulerInterface): ...@@ -941,6 +1026,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = ( req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids()) new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
...@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface): ...@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens, num_scheduled_tokens,
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids, req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
...@@ -1076,7 +1166,50 @@ class Scheduler(SchedulerInterface): ...@@ -1076,7 +1166,50 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items(): for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id] request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs. # Clear the finished request IDs.
...@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface): ...@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int], num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]], spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]], req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData: ) -> CachedRequestData:
req_ids: list[str] = [] req_ids: list[str] = []
new_token_ids: list[list[int]] = [] new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = [] new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
...@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface): ...@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do num_kv_tokens.append(req.num_kv_tokens)
# in-place appending so that we don't need to allocate a new list. resumed_from_preemption.append(
resumed_from_preemption = [False] * len(running_reqs) (req in resumed_reqs) or (req_id in force_replace_block_ids))
resumed_from_preemption += [True] * len(resumed_reqs)
return CachedRequestData( return CachedRequestData(
req_ids=req_ids, req_ids=req_ids,
...@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface): ...@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids, new_token_ids=new_token_ids,
new_block_ids=new_block_ids, new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens, num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
) )
def _try_schedule_encoder_inputs( def _try_schedule_encoder_inputs(
...@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface): ...@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling. # Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready. # Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
......
...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC): ...@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.free_blocks(ordered_blocks) self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None) self.num_cached_block.pop(request_id, None)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort optimization hook. Subclasses may override this
to free no-longer-needed blocks (e.g., after KV compaction). The default
implementation is a no-op.
"""
return False
@abstractmethod @abstractmethod
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager): ...@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention. # No need to remove blocks for full attention.
pass pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
num_tokens = max(int(num_tokens), 0)
blocks = self.req_to_blocks.get(request_id)
if not blocks:
return False
num_required_blocks = cdiv(num_tokens, self.block_size)
if num_required_blocks >= len(blocks):
return False
removed_blocks = blocks[num_required_blocks:]
del blocks[num_required_blocks:]
self.block_pool.free_blocks(reversed(removed_blocks))
if request_id in self.num_cached_block:
self.num_cached_block[request_id] = min(
self.num_cached_block[request_id], len(blocks))
return True
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id] blocks = self.req_to_blocks[request_id]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
start = _clamp_int(start_pos, 0, prompt_len)
end = _clamp_int(end_pos, 0, prompt_len)
if end <= start:
return 0
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
keep_prefix = _intersection_len(start, end, 0, prefix_len)
keep_suffix = _intersection_len(start, end, suffix_start, prompt_len)
overlap = _intersection_len(start, end, suffix_start, prefix_len)
kept = keep_prefix + keep_suffix - overlap
if keep_last_token:
last = prompt_len - 1
if start <= last < end:
already_kept = (last < prefix_len) or (last >= suffix_start)
if not already_kept:
kept += 1
return kept
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
return _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
total_keep = _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
if total_keep <= 0:
return 0
cand_upto_start = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=start_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
cand_upto_end = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
step_total = max(0, cand_upto_end - cand_upto_start)
if step_total == 0:
return 0
bud_upto_start = (total_keep * cand_upto_start) // total
bud_upto_end = (total_keep * cand_upto_end) // total
step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total)
...@@ -79,6 +79,10 @@ class Request: ...@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0 self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
......
...@@ -63,6 +63,11 @@ class BlockTable: ...@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None: def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0 self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx) self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
......
...@@ -38,6 +38,7 @@ class CachedRequestState: ...@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
spec_token_ids: list[int] = None spec_token_ids: list[int] = None
...@@ -114,6 +115,13 @@ class InputBatch: ...@@ -114,6 +115,13 @@ class InputBatch:
) )
self.num_computed_tokens_cpu = \ self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
...@@ -348,6 +356,7 @@ class InputBatch: ...@@ -348,6 +356,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index) self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params: if sampling_params := request.sampling_params:
...@@ -504,6 +513,8 @@ class InputBatch: ...@@ -504,6 +513,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\ self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1] self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\ self.top_p_cpu[i1], self.top_p_cpu[i2] =\
...@@ -602,6 +613,8 @@ class InputBatch: ...@@ -602,6 +613,8 @@ class InputBatch:
last_req_index] last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[ self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index] last_req_index]
......
...@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget ...@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec, KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec) SlidingWindowSpec)
from vllm.v1.kv_compression.budget import compute_topk_budget_step
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
...@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.attention_chunk_size = model_config.attention_chunk_size self.attention_chunk_size = model_config.attention_chunk_size
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_KV_COMPRESSION:
# KV compression changes the effective KV sequence layout and
# invalidates cascade attention assumptions (common-prefix blocks).
self.cascade_attn_enabled = False
# Whether the current step needs KV compaction work (score/topk/dst).
# This is set per-step in `_prepare_inputs`.
self.kv_compression_needs_compaction: bool = False
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
...@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy() self.positions_np = self.positions_cpu.numpy()
# KV positions are decoupled from logical positions when KV compression
# is enabled. We keep a separate buffer to avoid recomputing or
# overwriting `positions_np` (used for RoPE / input token lookup).
self.kv_positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.kv_positions_np = self.kv_positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
...@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
self.kv_compression_must_keep_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_must_keep_np = self.kv_compression_must_keep_cpu.numpy()
self.kv_compression_must_keep = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device=self.device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
self.kv_compression_topk_budget_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_topk_budget_np = self.kv_compression_topk_budget_cpu.numpy()
self.kv_compression_topk_budget = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
...@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
generator=generator, generator=generator,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[], output_token_ids=[],
lora_request=new_req_data.lora_request, lora_request=new_req_data.lora_request,
) )
...@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i] num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i] new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i] resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
spec_token_ids = ( spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
...@@ -545,6 +592,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -545,6 +592,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index) self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu # For the last rank, we don't need to update the token_ids_cpu
...@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
np.add(self.input_batch.num_computed_tokens_cpu[req_indices], np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange, arange,
out=positions_np) out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally
# kept and there is no Top-K budget, KV compaction is a no-op and we
# can skip score/topk/dst entirely in the attention backend.
self.kv_compression_needs_compaction = (not must_keep_np.all()) or (
topk_budget_np > 0).any()
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions. # Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
...@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
block_size = kv_cache_group_spec.kv_cache_spec.block_size block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[ block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id] kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2. # where K is the max_num_blocks_per_req and the block size is 2.
...@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# block_size. # block_size.
block_table_indices = ( block_table_indices = (
req_indices * block_table.max_num_blocks_per_req + req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size) slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor() block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten( block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy() )[block_table_indices].numpy()
block_offsets = positions_np % block_size block_offsets = slot_positions_np % block_size
np.add( np.add(
block_numbers * block_size, block_numbers * block_size,
block_offsets, block_offsets,
...@@ -707,6 +832,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -707,6 +832,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0 self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
if use_kv_compression:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = ( self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens) num_scheduled_tokens)
...@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
if use_kv_compression:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache # Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0) self.seq_lens[num_reqs:].fill_(0)
...@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
assert len(self.attn_backends) == 0 and len( assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders self.attn_metadata_builders
) == 0, "Attention backends are already initialized" ) == 0, "Attention backends are already initialized"
if envs.VLLM_ENABLE_KV_COMPRESSION and self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
for i, kv_cache_group_spec in enumerate( for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups): kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec kv_cache_spec = kv_cache_group_spec.kv_cache_spec
...@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
raise NotImplementedError( raise NotImplementedError(
"Non-Attention backend is not supported by V1 " "Non-Attention backend is not supported by V1 "
"GPUModelRunner.") "GPUModelRunner.")
if (envs.VLLM_ENABLE_KV_COMPRESSION
and attn_backend_i.get_name() != "FLASH_ATTN_VLLM_V1"):
raise ValueError(
"KV compression currently requires "
"VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.")
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
if envs.VLLM_ENABLE_KV_COMPRESSION:
raise ValueError(
"KV compression is currently only supported for "
"Transformer attention layers.")
attn_backend_i = Mamba2AttentionBackend attn_backend_i = Mamba2AttentionBackend
else: else:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment