Commit 3adc766e authored by laibao
Browse files

refactor: 抽离 flash_attn 的 KV compression 逻辑到 vllm/v1/kv_compression

parent 9db5ff3b
...@@ -2,13 +2,12 @@ ...@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
...@@ -33,18 +32,21 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -33,18 +32,21 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches) make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_compression.flash_attn_hooks import (
maybe_compact_kv_cache_flash_attn,
maybe_compute_prompt_end_payload_flash_attn,
)
from vllm.v1.kv_compression.metadata import build_kv_compression_attn_metadata
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
_DISABLE_SNAPKV_TRITON: bool = False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed. # NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
...@@ -280,37 +282,11 @@ class FlashAttentionMetadataBuilder( ...@@ -280,37 +282,11 @@ class FlashAttentionMetadataBuilder(
block_table.slot_mapping[num_actual_tokens:].fill_(-1) block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens] slot_mapping = block_table.slot_mapping[:num_actual_tokens]
kv_meta = build_kv_compression_attn_metadata(
kv_compression_must_keep = None runner=self.runner,
kv_compression_topk_budget = None num_reqs=num_reqs,
kv_compression_topk_budget_max: Optional[int] = None num_actual_tokens=num_actual_tokens,
kv_compression_prompt_end = None )
kv_compression_prompt_lens = None
kv_compression_prompt_topk_keep = None
kv_compression_prompt_topk_keep_max: Optional[int] = None
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.kv_compression_needs_compaction):
kv_compression_must_keep = self.runner.kv_compression_must_keep[:
num_actual_tokens]
kv_compression_topk_budget = self.runner.kv_compression_topk_budget[:
num_reqs]
# Avoid device->host sync by reading from the CPU staging buffer.
if num_reqs > 0:
kv_compression_topk_budget_max = int(
self.runner.kv_compression_topk_budget_np[:num_reqs].max())
else:
kv_compression_topk_budget_max = 0
elif (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.scheduler_config.chunked_prefill_enabled):
# Scheme 3: compute global prompt indices only on the last prefill
# chunk (per request), and perform the actual cache compaction
# before the first decode step.
if num_reqs > 0 and self.runner.kv_compression_prompt_end_np[:num_reqs].any():
kv_compression_prompt_end = self.runner.kv_compression_prompt_end[:num_reqs]
kv_compression_prompt_lens = self.runner.kv_compression_prompt_lens[:num_reqs]
kv_compression_prompt_topk_keep = self.runner.kv_compression_prompt_topk_keep[:num_reqs]
kv_compression_prompt_topk_keep_max = int(
self.runner.kv_compression_prompt_topk_keep_max or 0)
if self.aot_sliding_window is None: if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1) self.aot_sliding_window = (-1, -1)
...@@ -470,13 +446,13 @@ class FlashAttentionMetadataBuilder( ...@@ -470,13 +446,13 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens=cu_prefix_query_lens, cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens, prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
kv_compression_must_keep=kv_compression_must_keep, kv_compression_must_keep=kv_meta.must_keep,
kv_compression_topk_budget=kv_compression_topk_budget, kv_compression_topk_budget=kv_meta.topk_budget,
kv_compression_topk_budget_max=kv_compression_topk_budget_max, kv_compression_topk_budget_max=kv_meta.topk_budget_max,
kv_compression_prompt_end=kv_compression_prompt_end, kv_compression_prompt_end=kv_meta.prompt_end,
kv_compression_prompt_lens=kv_compression_prompt_lens, kv_compression_prompt_lens=kv_meta.prompt_lens,
kv_compression_prompt_topk_keep=kv_compression_prompt_topk_keep, kv_compression_prompt_topk_keep=kv_meta.prompt_topk_keep,
kv_compression_prompt_topk_keep_max=kv_compression_prompt_topk_keep_max, kv_compression_prompt_topk_keep_max=kv_meta.prompt_topk_keep_max,
local_attn_metadata=local_attn_metadata, local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits, max_num_splits=max_num_splits,
...@@ -651,32 +627,16 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -651,32 +627,16 @@ class FlashAttentionImpl(AttentionImpl):
layer._q_scale) layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size)) query = query.reshape((num_tokens, num_heads, head_size))
# Scheme 3 (chunked prefill): on the last prompt chunk, compute global if envs.VLLM_ENABLE_KV_COMPRESSION:
# prompt indices (score/topk) and cache them in the forward context for maybe_compute_prompt_end_payload_flash_attn(
# the model runner to consume before the first decode step. kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
if (envs.VLLM_ENABLE_KV_COMPRESSION query=query,
and self.kv_sharing_target_layer_name is None num_actual_tokens=num_actual_tokens,
and attn_metadata.kv_compression_prompt_end is not None key_cache=key_cache,
and attn_metadata.kv_compression_prompt_lens is not None cache_block_size=cache_block_size,
and attn_metadata.kv_compression_prompt_topk_keep is not None): attn_metadata=attn_metadata,
forward_context = get_forward_context() sm_scale=self.scale,
payload = getattr(forward_context, "_kv_compression_prompt_payload", )
None)
if payload is None:
payload = _compute_prompt_end_indices(
query=query[:num_actual_tokens],
key_cache=key_cache,
query_start_loc=attn_metadata.query_start_loc,
block_table=attn_metadata.block_table,
prompt_end=attn_metadata.kv_compression_prompt_end,
prompt_lens=attn_metadata.kv_compression_prompt_lens,
topk_keep=attn_metadata.kv_compression_prompt_topk_keep,
topk_keep_max=attn_metadata.kv_compression_prompt_topk_keep_max,
sm_scale=self.scale,
)
if payload is not None:
setattr(forward_context, "_kv_compression_prompt_payload",
payload)
# Compute attention and update output up to `num_actual_tokens`. # Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \ use_local_attn = \
...@@ -758,127 +718,24 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -758,127 +718,24 @@ class FlashAttentionImpl(AttentionImpl):
# Optional KV compaction pass for token-shared KV compression. # Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into # This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step. # a packed layout for the next step.
if (envs.VLLM_ENABLE_KV_COMPRESSION if envs.VLLM_ENABLE_KV_COMPRESSION:
and self.kv_sharing_target_layer_name is None): maybe_compact_kv_cache_flash_attn(
dst = None kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
if (attn_metadata.kv_compression_must_keep is not None layer=layer,
and attn_metadata.kv_compression_topk_budget query=query,
is not None): key=key,
forward_context = get_forward_context() value=value,
per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER key_cache=key_cache,
if per_layer_topk: value_cache=value_cache,
layer_name = getattr(layer, "layer_name", None) num_actual_tokens=num_actual_tokens,
if layer_name is None: cache_block_size=cache_block_size,
layer_name = str(id(layer)) attn_metadata=attn_metadata,
dst_by_layer = getattr( sm_scale=self.scale,
forward_context, "_kv_compression_compact_slots_by_layer", kv_cache_dtype=self.kv_cache_dtype,
None) reshape_and_cache=(reshape_and_cache_cuda
if dst_by_layer is None: if current_platform.is_rocm() else
dst_by_layer = {} reshape_and_cache_flash),
setattr( )
forward_context,
"_kv_compression_compact_slots_by_layer",
dst_by_layer,
)
dst = dst_by_layer.get(layer_name)
else:
dst = getattr(forward_context,
"_kv_compression_compact_slots", None)
if dst is None:
topk_budget = attn_metadata.kv_compression_topk_budget
token_scores: Optional[torch.Tensor] = None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max = int(
attn_metadata.kv_compression_topk_budget_max or 0)
if topk_budget_max > 0:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
w = torch.where(
topk_budget > 0,
torch.full_like(topk_budget, window),
torch.zeros_like(topk_budget),
)
token_scores = _snapkv_like_token_scores(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
query_start_loc=attn_metadata.query_start_loc,
window=w,
sm_scale=self.scale,
)
dst = _topk_kv_compact_slot_mapping(
token_scores=token_scores,
must_keep=attn_metadata.kv_compression_must_keep,
topk_budget=topk_budget,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
block_table=attn_metadata.block_table,
block_size=cache_block_size,
max_query_len=attn_metadata.max_query_len,
topk_budget_max=topk_budget_max,
)
if per_layer_topk:
dst_by_layer[layer_name] = dst
else:
setattr(forward_context,
"_kv_compression_compact_slots", dst)
if dst is not None:
src = attn_metadata.slot_mapping
rewrite_mask = (dst >= 0) & (dst != src)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite = torch.where(rewrite_mask, dst, -1)
def _writeback(dst_mapping: torch.Tensor) -> None:
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if (envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
and key.dtype == value.dtype
and key.dtype == torch.float16):
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
from vllm.attention.utils.fa_utils import (
reshape_and_cache_cuda)
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
_writeback(dst_rewrite)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
...@@ -937,251 +794,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -937,251 +794,6 @@ class FlashAttentionImpl(AttentionImpl):
return output return output
def _prompt_end_topk_keep_indices(
*,
token_scores: torch.Tensor, # [T] float32
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32 (candidates only)
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select kept prompt indices (ascending) for one-shot compaction.
Returns:
idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
keep_len: [B] int32, number of kept tokens per request
"""
device = token_scores.device
B = int(prompt_lens.numel())
if B == 0:
empty = torch.empty((0, 0), device=device, dtype=torch.int32)
return empty, torch.empty((0, ), device=device, dtype=torch.int32)
prompt_lens_i64 = prompt_lens.to(torch.long)
cu = torch.zeros((B + 1, ), device=device, dtype=torch.long)
cu[1:] = torch.cumsum(prompt_lens_i64, dim=0)
starts = cu[:B]
ends = cu[1:]
T = int(token_scores.numel())
if T == 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, torch.zeros((B, ), device=device, dtype=torch.int32)
token_idx = torch.arange(T, device=device, dtype=torch.long)
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T]
# Must-keep mask (protected prefix/suffix + optional last prompt token).
prefix_len = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_prefix, 0))
suffix = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_suffix, 0))
suffix_start = (prompt_lens_i64 - suffix).clamp_min(0)
prefix_len_t = prefix_len.index_select(0, req_ids)
suffix_start_t = suffix_start.index_select(0, req_ids)
must_keep = (pos_in_req < prefix_len_t) | (pos_in_req >= suffix_start_t)
if keep_last_token:
last = (prompt_lens_i64 - 1).clamp_min(0)
last_t = last.index_select(0, req_ids)
must_keep |= pos_in_req == last_t
cand_counts = torch.zeros((B, ), device=device, dtype=torch.long)
cand_counts.scatter_add_(0, req_ids, (~must_keep).to(torch.long))
k_eff = torch.minimum(topk_keep.to(torch.long).clamp_min(0), cand_counts)
# CPU-known bound avoids a device->host sync; clamp for safety.
if topk_keep_max is None:
k_max = int(k_eff.max().item())
else:
k_max = int(topk_keep_max)
if k_max < 0:
k_max = 0
keep_mask = must_keep.clone()
if k_max > 0:
L_max = int(prompt_lens_i64.max().item())
masked_scores = token_scores.to(torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
col_mask = (torch.arange(k_max, device=device).unsqueeze(0) <
k_eff.unsqueeze(1))
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
keep_max_len = int(keep_len.max().item()) if B > 0 else 0
if keep_max_len <= 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, keep_len.to(torch.int32)
# Stable, order-preserving index list using segment-local ranks.
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1)
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all))
prefix_before_t = prefix_before.index_select(0, req_ids)
local_rank = keep_prefix - prefix_before_t - 1 # [T]
idx_sorted = torch.zeros((B, keep_max_len), device=device, dtype=torch.int32)
lin_out = (req_ids * keep_max_len + local_rank).masked_select(keep_mask)
vals = pos_in_req.to(torch.int32).masked_select(keep_mask)
idx_sorted.view(-1).scatter_(0, lin_out, vals)
return idx_sorted, keep_len.to(torch.int32)
def _compute_prompt_end_indices(
*,
query: torch.Tensor, # [T, Hq, D] scheduled tokens for this step
key_cache: torch.Tensor, # layer KV cache view (platform-dependent)
query_start_loc: torch.Tensor, # [B+1] int32
block_table: torch.Tensor, # [B, max_blocks] int32
prompt_end: torch.Tensor, # [B] bool
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32
topk_keep_max: Optional[int],
sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
"""Compute one-shot prompt compaction indices on the last prefill chunk."""
device = query.device
if prompt_end.numel() == 0:
return None
sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
if int(sel.numel()) == 0:
return None
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)
# Build packed Q window (last `window` queries per selected request).
sel_list = sel.to(device="cpu", dtype=torch.int64).tolist()
qsl = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
q_chunks = []
cu_q = [0]
w_list = []
for b in sel_list:
s = int(qsl[b])
e = int(qsl[b + 1])
q_len = max(0, e - s)
win = min(window, q_len)
w_list.append(int(win))
if win > 0:
q_chunks.append(query[e - win:e])
cu_q.append(cu_q[-1] + int(win))
if cu_q[-1] <= 0:
return None
q_packed = torch.cat(q_chunks, dim=0) if q_chunks else query[:0]
cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
w = torch.tensor(w_list, device=device, dtype=torch.int32)
# Gather full prompt keys for the selected requests into a packed [T, Hk, D].
prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
cu_seqlens_k = torch.zeros((int(prompt_lens_sel.numel()) + 1, ),
device=device,
dtype=torch.int32)
if int(prompt_lens_sel.numel()) > 0:
cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
block_table_sel = block_table.index_select(0, sel).to(torch.int32)
if not current_platform.is_rocm():
# CUDA cache view: [num_blocks, block_size, H, D] -> [num_blocks, H, block_size, D]
key_cache_view = key_cache.permute(0, 2, 1, 3)
else:
key_cache_view = key_cache
from vllm.v1.attention.kv_compression.kv_cache_triton import (
gather_k_to_packed_triton)
k_packed = gather_k_to_packed_triton(
key_cache_view,
block_table_sel,
prompt_lens_sel,
cu_seqlens_k,
)
# SnapKV Triton scoring (token-shared via sum over KV heads).
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
try:
scores_per_head = query_aware_key_scores(
q=q_packed,
k=k_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
except Exception:
# Fallback: PyTorch reference scoring (slow but correctness-oriented).
Hq = q_packed.shape[1]
Hk = k_packed.shape[1]
D = q_packed.shape[2]
if Hq % Hk != 0:
raise
group = Hq // Hk
token_scores = torch.zeros((k_packed.shape[0], ),
device=device,
dtype=torch.float32)
for i in range(len(sel_list)):
qs = int(cu_q[i])
qe = int(cu_q[i + 1])
ks = int(cu_seqlens_k[i].item())
ke = int(cu_seqlens_k[i + 1].item())
if qe <= qs or ke <= ks:
continue
q_win = q_packed[qs:qe] # [win, Hq, D]
q_win = q_win.reshape(q_win.shape[0], Hk, group, D).mean(dim=2)
k_all = k_packed[ks:ke]
qh = q_win.permute(1, 0, 2).to(torch.float32)
kh = k_all.permute(1, 0, 2).to(torch.float32)
logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
probs = torch.softmax(logits, dim=-1)
token_scores[ks:ke] = probs.sum(dim=1).sum(dim=0)
from vllm.distributed.parallel_state import get_tp_group
token_scores = get_tp_group().all_reduce(token_scores)
idx_sorted, keep_len = _prompt_end_topk_keep_indices(
token_scores=token_scores,
prompt_lens=prompt_lens_sel,
topk_keep=topk_keep_sel,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
topk_keep_max=topk_keep_max,
)
return {
"req_indices": sel.to(torch.int32),
"idx_sorted": idx_sorted, # [B_sel, K_max] int32
"keep_len": keep_len, # [B_sel] int32
"prompt_lens": prompt_lens_sel, # [B_sel] int32
}
def use_cascade_attention( def use_cascade_attention(
common_prefix_len: int, common_prefix_len: int,
query_lens: np.ndarray, query_lens: np.ndarray,
...@@ -1393,225 +1005,3 @@ def cascade_attention( ...@@ -1393,225 +1005,3 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse) suffix_lse)
def _snapkv_like_token_scores(
*,
query: torch.Tensor, # [T, Hq, D]
key: torch.Tensor, # [T, Hkv, D]
query_start_loc: torch.Tensor, # [B+1]
window: Union[int, torch.Tensor],
sm_scale: float,
) -> torch.Tensor:
"""Compute token-shared SnapKV-like scores for a packed varlen batch.
Scores are computed as the attention mass from the last `window` query
tokens to the earlier keys within the same scheduled segment (per request),
summed across KV heads.
Prefers a Triton implementation when available; falls back to a (slower)
PyTorch reference implementation otherwise.
"""
global _DISABLE_SNAPKV_TRITON
device = query.device
T, Hq, D = query.shape
Hkv = key.shape[1]
if Hq % Hkv != 0:
raise ValueError("Query heads must be a multiple of KV heads.")
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if (HAS_TRITON and not _DISABLE_SNAPKV_TRITON and device.type == "cuda"
and (not current_platform.is_rocm()
or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
and query.stride(-1) == 1 and key.stride(-1) == 1):
try:
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
w = int(window) if isinstance(window, int) else window
scores_per_head = query_aware_key_scores(
q=query,
k=key,
cu_seqlens_q=query_start_loc,
cu_seqlens_k=query_start_loc,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(token_scores)
except Exception as e:
_DISABLE_SNAPKV_TRITON = True
logger.warning(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s", e)
group = Hq // Hkv
# Read boundaries on host (small tensor).
qsl = query_start_loc.tolist()
B = len(qsl) - 1
wsl = None
if not isinstance(window, int):
if int(window.numel()) != B:
raise ValueError("window must be a scalar int or have shape [B].")
wsl = window.to(device="cpu", dtype=torch.int64).tolist()
scores = torch.zeros((T, ), device=device, dtype=torch.float32)
for b in range(B):
s = int(qsl[b])
e = int(qsl[b + 1])
L = e - s
if L <= 0:
continue
win_b = int(window) if wsl is None else int(wsl[b])
if win_b <= 0:
continue
win = min(win_b, L)
k_eff_end = L - win
if k_eff_end <= 0:
continue
q_win = query[e - win:e] # [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2) # [win, Hkv, D]
k_eff = key[s:s + k_eff_end] # [k_eff_end, Hkv, D]
qh = q_win.permute(1, 0, 2).to(torch.float32) # [Hkv, win, D]
kh = k_eff.permute(1, 0, 2).to(torch.float32) # [Hkv, k_eff_end, D]
logits = torch.matmul(qh, kh.transpose(1, 2)) * sm_scale # [Hkv, win, K]
probs = torch.softmax(logits, dim=-1)
# Sum over (heads, window queries) -> per-key token score.
scores[s:s + k_eff_end] = probs.sum(dim=1).sum(dim=0)
# Aggregate across tensor-parallel ranks so every rank selects the same
# token indices.
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(scores)
def _topk_kv_compact_slot_mapping(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
query_start_loc: torch.Tensor, # [B+1]
seq_lens: torch.Tensor, # [B] int32
block_table: torch.Tensor, # [B, max_blocks]
block_size: int,
max_query_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0 or B == 0:
return dst_slots
# Per-request segment boundaries in the packed [T] layout.
# NOTE: `query_start_loc` is already sliced to [B+1] by the model runner.
starts = query_start_loc[:B].to(torch.long)
ends = query_start_loc[1:B + 1].to(torch.long)
lengths = ends - starts # [B]
if lengths.numel() == 0:
return dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max = int(max_query_len) if max_query_len is not None else int(
lengths.max().item())
if L_max <= 0:
return dst_slots
# Map each token to its (request, offset-within-request) coordinate.
token_idx = torch.arange(T, device=device, dtype=torch.long)
# For monotonic `ends` (cu_seqlens), this returns the request id for each
# token in the packed layout.
# Use right=True so that idx==ends[b] maps to the *next* request (b+1),
# i.e., request segments are [start, end) in the packed layout.
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T] in [0, L_b)
# Clamp the per-request top-k budget to the number of candidate tokens
# (excluding must_keep).
must_keep_counts = torch.zeros(B, device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# Prefer an upper bound from CPU (piecewise graph), to avoid sync.
if topk_budget_max is not None:
k_max = min(int(topk_budget_max), L_max)
else:
k_max = int(k_eff.max().item())
# Build a padded [B, L_max] score matrix for a single batched Top-K call.
# Must-keep and padding positions are set to -inf to avoid selection.
keep_mask = must_keep.clone()
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(dtype=torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
# Select only the first k_eff[b] entries for each request b.
col_mask = torch.arange(k_max, device=device).unsqueeze(
0) < k_eff.unsqueeze(1) # [B, k_max]
# Avoid host-side synchronization from dynamic indexing. Instead, mark
# selected tokens via a fixed-size scatter-add.
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B, k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
# Compute segment-local ranks (0..kept-1) for kept tokens, preserving token
# order within each request, without dynamic indexing (graph-friendly).
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1.to(torch.long))
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
dst_slots = torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
return dst_slots
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
import vllm.envs as envs
from vllm.v1.kv_compression.slot_mapping import topk_kv_compact_slot_mapping
from vllm.v1.kv_compression.snapkv_score import snapkv_like_token_scores
def snapkv_window_for_topk_budget(
    *,
    topk_budget: torch.Tensor,  # [B] int32
    window: int,
) -> torch.Tensor:
    """Return the SnapKV window to use for each request of a mixed batch.

    Requests whose Top-K budget is positive get ``window``; all others get 0.
    Per the original author's note, a zero window lets the Triton scoring
    kernel early-return for requests that need no token scores.

    The result has the same shape, dtype and device as ``topk_budget``.
    """
    needs_scores = topk_budget > 0
    full_window = torch.full_like(topk_budget, int(window))
    no_window = torch.zeros_like(topk_budget)
    return torch.where(needs_scores, full_window, no_window)
def compute_compact_dst_slots_for_step(
    *,
    query: torch.Tensor,  # [T, Hq, D] for this step
    key: torch.Tensor,  # [T, Hkv, D] for this step
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    topk_budget_max: int,
    max_query_len: int,
    sm_scale: float,
) -> torch.Tensor:
    """Return the per-token KV compaction destinations for one step.

    Token scores are computed only when ``topk_budget_max`` is positive;
    otherwise ``None`` is forwarded to ``topk_kv_compact_slot_mapping`` and
    the scoring call is skipped. The scoring window comes from
    ``VLLM_KV_COMPRESSION_SNAPKV_WINDOW`` and is zeroed for requests whose
    budget is zero.
    """
    budget_cap = int(topk_budget_max)

    if budget_cap > 0:
        per_request_window = snapkv_window_for_topk_budget(
            topk_budget=topk_budget,
            window=int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW),
        )
        scores = snapkv_like_token_scores(
            query=query,
            key=key,
            query_start_loc=query_start_loc,
            window=per_request_window,
            sm_scale=float(sm_scale),
        )
    else:
        # No request has a Top-K budget: selection needs no scores.
        scores = None

    return topk_kv_compact_slot_mapping(
        token_scores=scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_table=block_table,
        block_size=int(block_size),
        max_query_len=int(max_query_len),
        topk_budget_max=budget_cap,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional, Protocol
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_compression.compaction_step import compute_compact_dst_slots_for_step
from vllm.v1.kv_compression.forward_context import (
get_kv_compression_compact_slots,
get_kv_compression_prompt_payload,
set_kv_compression_compact_slots,
set_kv_compression_prompt_payload,
)
from vllm.v1.kv_compression.prompt_end import compute_prompt_end_indices
from vllm.v1.kv_compression.slot_mapping import kv_compaction_dst_rewrite_mapping
class _ReshapeAndCacheFn(Protocol):
def __call__(
self,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None: ...
def maybe_compute_prompt_end_payload_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    query: torch.Tensor,
    num_actual_tokens: int,
    key_cache: torch.Tensor,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
) -> None:
    """Compute and stash prompt-end Top-K indices (chunked-prefill scheme 3).

    The result is stored on the current forward context; it is meant to be
    consumed later by the model runner for one-shot prompt KV compaction
    before the first decode step.

    The function is a no-op when KV compression is disabled, when the layer
    shares its KV cache with another layer, when the attention metadata lacks
    any prompt-end field, or when a payload is already stored on the context.
    """
    compression_active = (envs.VLLM_ENABLE_KV_COMPRESSION
                          and kv_sharing_target_layer_name is None)
    if not compression_active:
        return

    prompt_fields = tuple(
        getattr(attn_metadata, field_name, None) for field_name in (
            "kv_compression_prompt_end",
            "kv_compression_prompt_lens",
            "kv_compression_prompt_topk_keep",
        ))
    if any(field is None for field in prompt_fields):
        return
    prompt_end, prompt_lens, topk_keep = prompt_fields

    ctx = get_forward_context()
    # Only the first caller in a forward pass computes the payload.
    if get_kv_compression_prompt_payload(ctx) is not None:
        return

    keep_cap = getattr(attn_metadata, "kv_compression_prompt_topk_keep_max",
                       None)
    result = compute_prompt_end_indices(
        query=query[:num_actual_tokens],
        key_cache=key_cache,
        block_size=cache_block_size,
        query_start_loc=attn_metadata.query_start_loc,
        block_table=attn_metadata.block_table,
        prompt_end=prompt_end,
        prompt_lens=prompt_lens,
        topk_keep=topk_keep,
        topk_keep_max=keep_cap,
        sm_scale=sm_scale,
    )
    if result is None:
        return
    set_kv_compression_prompt_payload(ctx, result)
def maybe_compact_kv_cache_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    layer: Any,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_actual_tokens: int,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
    kv_cache_dtype: str,
    reshape_and_cache: _ReshapeAndCacheFn,
) -> None:
    """Optional per-step KV compaction for scheme 1/2 token-shared selection.

    Destination slots are looked up on the forward context (per layer when
    ``VLLM_KV_COMPRESSION_TOPK_PER_LAYER`` is set) and computed and stored
    there on a miss. Kept tokens whose destination differs from their source
    slot are then rewritten into the KV cache.

    The function is a no-op when KV compression is disabled, when the layer
    shares its KV cache with another layer, or when the attention metadata
    carries no ``kv_compression_must_keep`` / ``kv_compression_topk_budget``.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if kv_sharing_target_layer_name is not None:
        return

    must_keep = getattr(attn_metadata, "kv_compression_must_keep", None)
    topk_budget = getattr(attn_metadata, "kv_compression_topk_budget", None)
    if must_keep is None or topk_budget is None:
        return

    ctx = get_forward_context()
    per_layer = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER
    compact_slots = get_kv_compression_compact_slots(
        ctx,
        per_layer_topk=per_layer,
        layer=layer,
    )
    if compact_slots is None:
        budget_cap = int(
            getattr(attn_metadata, "kv_compression_topk_budget_max", 0) or 0)
        compact_slots = compute_compact_dst_slots_for_step(
            query=query[:num_actual_tokens],
            key=key[:num_actual_tokens],
            query_start_loc=attn_metadata.query_start_loc,
            seq_lens=attn_metadata.seq_lens,
            block_table=attn_metadata.block_table,
            block_size=cache_block_size,
            must_keep=must_keep,
            topk_budget=topk_budget,
            topk_budget_max=budget_cap,
            max_query_len=attn_metadata.max_query_len,
            sm_scale=sm_scale,
        )
        set_kv_compression_compact_slots(
            ctx,
            per_layer_topk=per_layer,
            layer=layer,
            dst=compact_slots,
        )
        if compact_slots is None:
            return

    rewrite_slots = kv_compaction_dst_rewrite_mapping(
        dst_slots=compact_slots,
        src_slots=attn_metadata.slot_mapping,
    )

    # Default to the caller-provided kernel; on ROCm optionally prefer the
    # optimized fp16 reshape-and-cache kernel.
    write_kv = reshape_and_cache
    if (current_platform.is_rocm() and envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
            and key.dtype == value.dtype and key.dtype == torch.float16):
        from lightop import reshape_and_cache_cuda
        write_kv = reshape_and_cache_cuda

    write_kv(
        key,
        value,
        key_cache,
        value_cache,
        rewrite_slots,
        kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import torch
_PROMPT_PAYLOAD_ATTR = "_kv_compression_prompt_payload"
_COMPACT_SLOTS_ATTR = "_kv_compression_compact_slots"
_COMPACT_SLOTS_BY_LAYER_ATTR = "_kv_compression_compact_slots_by_layer"
def get_kv_compression_prompt_payload(
    forward_context: Any,
) -> Optional[dict[str, torch.Tensor]]:
    """Return the prompt payload stored on ``forward_context``, if any.

    Returns ``None`` when no payload attribute has been set on the context.
    """
    stored_payload = getattr(forward_context, _PROMPT_PAYLOAD_ATTR, None)
    return stored_payload
def set_kv_compression_prompt_payload(
    forward_context: Any,
    payload: dict[str, torch.Tensor],
) -> None:
    """Store ``payload`` on ``forward_context``, replacing any previous one."""
    setattr(
        forward_context,
        _PROMPT_PAYLOAD_ATTR,
        payload,
    )
def _kv_compression_layer_key(layer: Any) -> str:
layer_name = getattr(layer, "layer_name", None)
if layer_name is None:
layer_name = str(id(layer))
return str(layer_name)
def get_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
) -> Optional[torch.Tensor]:
    """Look up the compaction destination slots stored on the context.

    With ``per_layer_topk`` the slots are read from a per-layer dictionary
    keyed by ``_kv_compression_layer_key(layer)``; otherwise a single entry
    on the context is used. Returns ``None`` when nothing has been stored.
    """
    if not per_layer_topk:
        return getattr(forward_context, _COMPACT_SLOTS_ATTR, None)

    slots_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
                             None)
    if slots_by_layer is None:
        return None
    return slots_by_layer.get(_kv_compression_layer_key(layer))
def set_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
    dst: torch.Tensor,
) -> None:
    """Store compaction destination slots on the forward context.

    With ``per_layer_topk`` the slots go into a per-layer dictionary (created
    on first use) keyed by ``_kv_compression_layer_key(layer)``; otherwise
    they overwrite the single entry on the context.
    """
    if not per_layer_topk:
        setattr(forward_context, _COMPACT_SLOTS_ATTR, dst)
        return

    slots_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
                             None)
    if slots_by_layer is None:
        slots_by_layer = {}
        setattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR, slots_by_layer)
    slots_by_layer[_kv_compression_layer_key(layer)] = dst
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
from vllm.platforms import current_platform
def paged_k_cache_view_for_triton_gather(
    *,
    key_cache: torch.Tensor,
    block_size: int,
) -> torch.Tensor:
    """Return a KV-cache key view in [num_blocks, H, block_size, D] layout.

    Supports both:
      - [num_blocks, block_size, H, D] (typical CUDA FlashAttention v1 layout)
      - [num_blocks, H, block_size, D] (ROCm FlashAttention v1, or external
        connectors that expose the cache in HND shape)

    The layout is inferred from the shape only when exactly one of axes 1 and
    2 equals ``block_size``. When both match (``H == block_size``) or neither
    does, the shape is inconclusive and the platform default is used: the
    cache is returned unchanged on ROCm and permuted elsewhere.

    Args:
        key_cache: 4D key cache tensor.
        block_size: Number of tokens per cache block.

    Returns:
        A view of ``key_cache`` (no copy is made).

    Raises:
        ValueError: If ``key_cache`` is not 4-dimensional.
    """
    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor.")

    tokens_per_block = int(block_size)
    axis1_is_tokens = int(key_cache.shape[1]) == tokens_per_block
    axis2_is_tokens = int(key_cache.shape[2]) == tokens_per_block

    if axis1_is_tokens != axis2_is_tokens:
        # Exactly one axis has the block size, so the layout is unambiguous.
        if axis1_is_tokens:
            # [B, T, H, D] -> [B, H, T, D]
            return key_cache.permute(0, 2, 1, 3)
        # Already in [B, H, T, D].
        return key_cache

    # Inconclusive shape: preserve the historical platform-based behavior.
    if current_platform.is_rocm():
        return key_cache
    return key_cache.permute(0, 2, 1, 3)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import torch
import vllm.envs as envs
@dataclass
class KVCompressionAttentionMetadata:
    """Per-batch KV compression metadata consumed by attention backends.

    Every field defaults to ``None``, which means the corresponding
    compression path is inactive for the current step.
    """

    # Inputs for per-step compaction (scheme 1/2).
    must_keep: Optional[torch.Tensor] = None
    topk_budget: Optional[torch.Tensor] = None
    topk_budget_max: Optional[int] = None

    # Inputs for prompt-end one-shot scoring (scheme 3).
    prompt_end: Optional[torch.Tensor] = None
    prompt_lens: Optional[torch.Tensor] = None
    prompt_topk_keep: Optional[torch.Tensor] = None
    prompt_topk_keep_max: Optional[int] = None
def build_kv_compression_attn_metadata(
    *,
    runner: Any,
    num_reqs: int,
    num_actual_tokens: int,
) -> KVCompressionAttentionMetadata:
    """Build KV compression metadata for one attention step.

    Centralizes the choice between per-step compaction (scheme 1/2) and
    prompt-end one-shot scoring (scheme 3) so backend code stays thin. An
    all-``None`` metadata object is returned whenever neither path applies.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return KVCompressionAttentionMetadata()

    # Scheme 1/2: compaction destinations are computed every step.
    if getattr(runner, "kv_compression_needs_compaction", False):
        must_keep = runner.kv_compression_must_keep[:num_actual_tokens]
        topk_budget = runner.kv_compression_topk_budget[:num_reqs]
        # Read the bound from the CPU staging buffer to avoid a
        # device->host sync.
        budget_cap = (int(runner.kv_compression_topk_budget_np[:num_reqs].max())
                      if num_reqs > 0 else 0)
        return KVCompressionAttentionMetadata(
            must_keep=must_keep,
            topk_budget=topk_budget,
            topk_budget_max=budget_cap,
        )

    # Scheme 3: global prompt indices are computed only on the last prefill
    # chunk; the cache compaction itself happens before the first decode step.
    sched_cfg = getattr(runner, "scheduler_config", None)
    chunked_prefill = sched_cfg is not None and getattr(
        sched_cfg, "chunked_prefill_enabled", False)
    if not chunked_prefill or num_reqs <= 0:
        return KVCompressionAttentionMetadata()
    if not runner.kv_compression_prompt_end_np[:num_reqs].any():
        return KVCompressionAttentionMetadata()

    return KVCompressionAttentionMetadata(
        prompt_end=runner.kv_compression_prompt_end[:num_reqs],
        prompt_lens=runner.kv_compression_prompt_lens[:num_reqs],
        prompt_topk_keep=runner.kv_compression_prompt_topk_keep[:num_reqs],
        prompt_topk_keep_max=int(
            getattr(runner, "kv_compression_prompt_topk_keep_max", 0) or 0),
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
import vllm.envs as envs
from vllm.v1.kv_compression.kv_cache_view import paged_k_cache_view_for_triton_gather
from vllm.v1.kv_compression.snapkv_score import snapkv_query_aware_token_scores
from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
_topk_keep_mask_and_local_rank)
def _prompt_end_topk_keep_indices(
    *,
    token_scores: torch.Tensor,
    prompt_lens: torch.Tensor,
    topk_keep: torch.Tensor,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select kept prompt indices (ascending) for one-shot compaction.

    Args:
        token_scores: [T] float32 scores over the packed prompts.
        prompt_lens: [B] int32 prompt length of each request.
        topk_keep: [B] int32 Top-K budget (candidates only).
        protected_prefix: Number of leading tokens that are always kept.
        protected_suffix: Number of trailing tokens that are always kept.
        keep_last_token: Whether the last prompt token is always kept.
        topk_keep_max: Optional host-side bound on ``topk_keep``.

    Returns:
        idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
        keep_len:   [B] int32, number of kept tokens per request
    """
    dev = token_scores.device
    num_reqs = int(prompt_lens.numel())
    if num_reqs == 0:
        return (torch.empty((0, 0), device=dev, dtype=torch.int32),
                torch.empty((0, ), device=dev, dtype=torch.int32))

    lens = prompt_lens.to(torch.long)
    cu_lens = torch.zeros((num_reqs + 1, ), device=dev, dtype=torch.long)
    cu_lens[1:] = torch.cumsum(lens, dim=0)

    num_tokens = int(token_scores.numel())
    if num_tokens == 0:
        return (torch.empty((num_reqs, 0), device=dev, dtype=torch.int32),
                torch.zeros((num_reqs, ), device=dev, dtype=torch.int32))

    seg_starts, _, seg_lens, req_of_tok, pos = _packed_varlen_coords(
        cu_seqlens=cu_lens,
        total_tokens=num_tokens,
    )

    # Protected tokens: prefix, suffix and (optionally) the last prompt token.
    nonneg_lens = torch.clamp(lens, min=0)
    head_len = nonneg_lens.clamp_max(max(protected_prefix, 0))
    tail_len = nonneg_lens.clamp_max(max(protected_suffix, 0))
    tail_start = (lens - tail_len).clamp_min(0)
    in_head = pos < head_len.index_select(0, req_of_tok)
    in_tail = pos >= tail_start.index_select(0, req_of_tok)
    protected = in_head | in_tail
    if keep_last_token:
        last_pos = (lens - 1).clamp_min(0).index_select(0, req_of_tok)
        protected |= pos == last_pos

    kept, rank, kept_counts = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=protected,
        topk_budget=topk_keep,
        starts=seg_starts,
        lengths=seg_lens,
        req_ids=req_of_tok,
        pos_in_req=pos,
        max_len=int(lens.max().item()),
        topk_budget_max=topk_keep_max,
    )

    widest = int(kept_counts.max().item())
    if widest <= 0:
        return (torch.empty((num_reqs, 0), device=dev, dtype=torch.int32),
                kept_counts)

    # Scatter each kept token's in-prompt position to row=request, col=rank.
    kept_idx = torch.zeros((num_reqs, widest), device=dev, dtype=torch.int32)
    flat_dst = (req_of_tok * widest + rank).masked_select(kept)
    flat_src = pos.to(torch.int32).masked_select(kept)
    kept_idx.view(-1).scatter_(0, flat_dst, flat_src)
    return kept_idx, kept_counts
def compute_prompt_end_indices(
    *,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    block_size: int,
    query_start_loc: torch.Tensor,
    block_table: torch.Tensor,
    prompt_end: torch.Tensor,
    prompt_lens: torch.Tensor,
    topk_keep: torch.Tensor,
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    Args:
        query: [T, Hq, D] scheduled tokens for this step.
        key_cache: Layer KV cache view (platform-dependent layout).
        block_size: Number of tokens per KV cache block.
        query_start_loc: [B+1] int32 packed query boundaries.
        block_table: [B, max_blocks] int32 block table.
        prompt_end: [B] bool, requests whose prompt ends in this step.
        prompt_lens: [B] int32 prompt lengths.
        topk_keep: [B] int32 Top-K budgets.
        topk_keep_max: Optional host-side bound on ``topk_keep``.
        sm_scale: Softmax scale used when scoring tokens.

    Returns:
        ``None`` when no request is selected or no query token falls into a
        scoring window; otherwise a dict with ``req_indices``, ``idx_sorted``
        ([B_sel, K_max] int32), ``keep_len`` ([B_sel] int32) and
        ``prompt_lens`` ([B_sel] int32).
    """
    dev = query.device
    if prompt_end.numel() == 0:
        return None
    sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(sel.numel()) == 0:
        return None

    snap_window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    head_protect = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    tail_protect = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Pack the last `window` queries of every selected request.
    sel_host = sel.to(device="cpu", dtype=torch.int64).tolist()
    qsl_host = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
    win_sizes = []
    q_tails = []
    q_offsets = [0]
    for req in sel_host:
        q_begin = int(qsl_host[req])
        q_end = int(qsl_host[req + 1])
        win = int(min(snap_window, max(0, q_end - q_begin)))
        win_sizes.append(win)
        q_offsets.append(q_offsets[-1] + win)
        if win > 0:
            q_tails.append(query[q_end - win:q_end])
    if q_offsets[-1] <= 0:
        return None

    q_packed = torch.cat(q_tails, dim=0) if q_tails else query[:0]
    cu_seqlens_q = torch.tensor(q_offsets, device=dev, dtype=torch.int32)
    per_req_window = torch.tensor(win_sizes, device=dev, dtype=torch.int32)

    # Gather the full prompt keys of the selected requests into a packed
    # [T, Hk, D] tensor.
    prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
    num_sel = int(prompt_lens_sel.numel())
    cu_seqlens_k = torch.zeros((num_sel + 1, ), device=dev, dtype=torch.int32)
    if num_sel > 0:
        cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
    block_table_sel = block_table.index_select(0, sel).to(torch.int32)

    k_view = paged_k_cache_view_for_triton_gather(
        key_cache=key_cache,
        block_size=int(block_size),
    )
    from vllm.v1.kv_compression.kv_cache_triton import (
        gather_k_to_packed_triton)
    k_packed = gather_k_to_packed_triton(
        k_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    scores = snapkv_query_aware_token_scores(
        query=q_packed,
        key=k_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        window=per_req_window,
        sm_scale=float(sm_scale),
    )
    kept_idx, kept_counts = _prompt_end_topk_keep_indices(
        token_scores=scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=head_protect,
        protected_suffix=tail_protect,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )

    payload = {
        "req_indices": sel.to(torch.int32),
        "idx_sorted": kept_idx,
        "keep_len": kept_counts,
        "prompt_lens": prompt_lens_sel,
    }
    return payload
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
_topk_keep_mask_and_local_rank)
def _dst_slots_from_keep_mask_and_local_rank(
*,
keep_mask: torch.Tensor, # [T] bool
local_rank: torch.Tensor, # [T] int64
seq_lens: torch.Tensor, # [B] int32
lengths: torch.Tensor, # [B] int64
req_ids: torch.Tensor, # [T] int64
block_table: torch.Tensor, # [B, max_blocks] int32
block_size: int,
) -> torch.Tensor:
"""Convert keep_mask/local_rank into a per-token KV destination slot."""
device = keep_mask.device
T = int(keep_mask.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0:
return dst_slots
B = int(seq_lens.numel())
if B == 0:
return dst_slots
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
return torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
def topk_kv_compact_slot_mapping(
    *,
    token_scores: Optional[torch.Tensor],
    must_keep: torch.Tensor,
    topk_budget: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    max_query_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
    """Build a per-token destination slot mapping for KV compaction.

    Args:
        token_scores: [T] float32 scores, or ``None`` when no scoring ran.
        must_keep: [T] bool mask of tokens that are always kept.
        topk_budget: [B] int32 per-request Top-K budget.
        query_start_loc: [B+1] packed segment boundaries.
        seq_lens: [B] int32 sequence lengths.
        block_table: [B, max_blocks] block table.
        block_size: Number of tokens per KV cache block.
        max_query_len: Optional host-side maximum segment length.
        topk_budget_max: Optional host-side bound on ``topk_budget``.

    Returns:
        A tensor ``dst_slots`` of shape [T] where:
          - ``dst_slots[i] >= 0`` means token i is kept and rewritten to that
            KV cache slot.
          - ``dst_slots[i] == -1`` means token i is dropped after the step.
    """
    dev = must_keep.device
    num_tokens = int(must_keep.numel())
    num_reqs = int(topk_budget.numel())
    all_dropped = torch.full((num_tokens, ),
                             -1,
                             device=dev,
                             dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    seg_starts, _, seg_lens, req_of_tok, pos = _packed_varlen_coords(
        cu_seqlens=query_start_loc,
        total_tokens=num_tokens,
    )
    if seg_lens.numel() == 0:
        return all_dropped

    # Prefer the CPU-known max query length (piecewise graph) to avoid a
    # device->host synchronization.
    if max_query_len is not None:
        longest = int(max_query_len)
    else:
        longest = int(seg_lens.max().item())
    if longest <= 0:
        return all_dropped

    kept, rank, _ = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        starts=seg_starts,
        lengths=seg_lens,
        req_ids=req_of_tok,
        pos_in_req=pos,
        max_len=longest,
        topk_budget_max=topk_budget_max,
    )
    return _dst_slots_from_keep_mask_and_local_rank(
        keep_mask=kept,
        local_rank=rank,
        seq_lens=seq_lens[:num_reqs],
        lengths=seg_lens,
        req_ids=req_of_tok,
        block_table=block_table,
        block_size=int(block_size),
    )
def kv_compaction_dst_rewrite_mapping(
    *,
    dst_slots: torch.Tensor,
    src_slots: torch.Tensor,
) -> torch.Tensor:
    """Filter a dst slot mapping so only moved kept tokens are rewritten.

    A token is rewritten only when it is kept (``dst_slots >= 0``) and its
    destination differs from its source slot. Every other entry becomes -1,
    which the cache kernels treat as padding and skip.

    Args:
        dst_slots: [T] int64 destination slots (-1 for dropped tokens).
        src_slots: [T] int64 slots the tokens currently occupy.
    """
    skip_write = (dst_slots < 0) | (dst_slots == src_slots)
    padding = torch.full_like(dst_slots, -1)
    return torch.where(skip_write, padding, dst_slots)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Union
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
logger = init_logger(__name__)
_DISABLE_SNAPKV_TRITON: bool = False
def _snapkv_reference_token_scores(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """PyTorch reference SnapKV scoring (local to this rank, not reduced).

    For every request, the last ``win`` queries attend to the keys that
    precede the last ``win`` keys; the softmax probabilities are summed over
    the window queries and KV heads. Keys inside the window keep a score of 0.
    """
    _, Hq, D = query.shape
    N_k, Hkv, _ = key.shape
    group = Hq // Hkv

    # Read boundaries and (optional) per-request window sizes on host.
    qsl = cu_seqlens_q.tolist()
    ksl = cu_seqlens_k.tolist()
    B = len(qsl) - 1
    if len(ksl) - 1 != B:
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
    wsl = None
    if not isinstance(window, int):
        if int(window.numel()) != B:
            raise ValueError("window must be a scalar int or have shape [B].")
        wsl = window.to(device="cpu", dtype=torch.int64).tolist()

    scores = torch.zeros((N_k, ), device=query.device, dtype=torch.float32)
    for b in range(B):
        qs, qe = int(qsl[b]), int(qsl[b + 1])
        ks, ke = int(ksl[b]), int(ksl[b + 1])
        q_len = qe - qs
        k_len = ke - ks
        if q_len <= 0 or k_len <= 0:
            continue
        win_b = int(window) if wsl is None else int(wsl[b])
        if win_b <= 0:
            continue
        win = min(win_b, q_len, k_len)
        k_eff_end = ke - win
        if k_eff_end <= ks:
            continue
        q_win = query[qe - win:qe]  # [win, Hq, D]
        # Aggregate query heads to KV heads (token-shared selection).
        q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2)  # [win, Hkv, D]
        k_eff = key[ks:k_eff_end]  # [K_eff, Hkv, D]
        qh = q_win.permute(1, 0, 2).to(torch.float32)  # [Hkv, win, D]
        kh = k_eff.permute(1, 0, 2).to(torch.float32)  # [Hkv, K_eff, D]
        logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
        probs = torch.softmax(logits, dim=-1)
        scores[ks:k_eff_end] = probs.sum(dim=1).sum(dim=0)
    return scores


def snapkv_query_aware_token_scores(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """Compute token-shared SnapKV scores for packed, varlen q/k inputs.

    The Triton kernel is used when available; if it raises, Triton scoring is
    disabled for the rest of the process and the PyTorch reference
    implementation is used instead. Exactly one TP all-reduce is issued per
    call regardless of which path produced the scores, and it runs outside
    the fallback ``try`` so a collective failure is never mistaken for a
    Triton failure.

    Args:
        query: [N_q, Hq, D] packed queries.
        key: [N_k, Hkv, D] packed keys.
        cu_seqlens_q: [B+1] packed query boundaries.
        cu_seqlens_k: [B+1] packed key boundaries.
        window: Scalar window size, or a [B] tensor of per-request sizes.
        sm_scale: Softmax scale applied to the attention logits.

    Returns:
        A [N_k] float32 tensor, reduced across TP ranks so every rank makes
        an identical Top-K selection.

    Raises:
        ValueError: If the inputs are not 3D, head sizes differ, the query
            head count is not a multiple of the KV head count, or the
            boundary/window tensors do not match.
    """
    global _DISABLE_SNAPKV_TRITON

    device = query.device
    if query.ndim != 3 or key.ndim != 3:
        raise ValueError("query and key must be 3D tensors.")
    _, Hq, D = query.shape
    _, Hkv, Dk = key.shape
    if D != Dk:
        raise ValueError("query and key must have the same head size.")
    if Hq % Hkv != 0:
        raise ValueError("Query heads must be a multiple of KV heads.")
    if cu_seqlens_q.numel() != cu_seqlens_k.numel():
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")

    # NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
    # default (uses a ROCm-safe kernel); set
    # VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
    # reference implementation.
    use_triton = (HAS_TRITON and not _DISABLE_SNAPKV_TRITON
                  and device.type == "cuda"
                  and (not current_platform.is_rocm()
                       or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
                  and query.stride(-1) == 1 and key.stride(-1) == 1)

    token_scores = None
    if use_triton:
        # Only the Triton scoring is guarded: the all-reduce below must not
        # be swallowed here, or a failed collective would be retried.
        try:
            from vllm.v1.kv_compression.snapkv_triton import (
                query_aware_key_scores)
            scores_per_head = query_aware_key_scores(
                q=query,
                k=key,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                w=window,
                sm_scale=float(sm_scale),
                pool=False,
                protect_last=False,
                normalize=False,
            )
            token_scores = scores_per_head.sum(dim=1)
        except Exception as e:
            _DISABLE_SNAPKV_TRITON = True
            token_scores = None
            logger.warning(
                "Triton SnapKV scoring failed; falling back to PyTorch. "
                "Error: %s", e)

    if token_scores is None:
        token_scores = _snapkv_reference_token_scores(
            query=query,
            key=key,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            window=window,
            sm_scale=sm_scale,
        )

    from vllm.distributed.parallel_state import get_tp_group
    return get_tp_group().all_reduce(token_scores)
def snapkv_like_token_scores(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    query_start_loc: torch.Tensor,
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """SnapKV-like token scores when q/k share the same packed layout.

    ``query`` ([T, Hq, D]) and ``key`` ([T, Hkv, D]) are segmented by the
    same ``query_start_loc`` ([B+1]), which is forwarded as both the query
    and the key boundaries.
    """
    shared_boundaries = query_start_loc
    return snapkv_query_aware_token_scores(
        query=query,
        key=key,
        cu_seqlens_q=shared_boundaries,
        cu_seqlens_k=shared_boundaries,
        window=window,
        sm_scale=sm_scale,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
def _packed_varlen_coords(
*,
cu_seqlens: torch.Tensor, # [B+1]
total_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute packed varlen segment coordinates.
Returns:
starts: [B] int64, segment start offsets (inclusive)
ends: [B] int64, segment end offsets (exclusive)
lengths: [B] int64, segment lengths (ends - starts)
req_ids: [T] int64, request id for each token in packed [0, T)
pos_in_req: [T] int64, position within its request segment
"""
device = cu_seqlens.device
B = int(cu_seqlens.numel() - 1)
if B <= 0:
empty = torch.empty((0, ), device=device, dtype=torch.long)
t_empty = torch.empty((0, ), device=device, dtype=torch.long)
return empty, empty, empty, t_empty, t_empty
starts = cu_seqlens[:B].to(torch.long)
ends = cu_seqlens[1:B + 1].to(torch.long)
lengths = ends - starts
if total_tokens <= 0:
t_empty = torch.empty((0, ), device=device, dtype=torch.long)
return starts, ends, lengths, t_empty, t_empty
token_idx = torch.arange(total_tokens, device=device, dtype=torch.long)
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token
return starts, ends, lengths, req_ids, pos_in_req
def _topk_keep_mask_and_local_rank(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
starts: torch.Tensor, # [B] int64
lengths: torch.Tensor, # [B] int64
req_ids: torch.Tensor, # [T] int64
pos_in_req: torch.Tensor, # [T] int64
max_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute keep_mask/local_rank for token-shared Top-K selection.
Returns:
keep_mask: [T] bool, selected tokens (includes must_keep)
local_rank: [T] int64, rank among kept tokens within each request
keep_len: [B] int32, number of kept tokens per request
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
keep_mask = must_keep.clone()
if T == 0 or B == 0:
local_rank = torch.empty((T, ), device=device, dtype=torch.long)
keep_len = torch.zeros((B, ), device=device, dtype=torch.int32)
return keep_mask, local_rank, keep_len
if max_len is None:
L_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
else:
L_max = int(max_len)
if L_max < 0:
L_max = 0
must_keep_counts = torch.zeros((B, ), device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# CPU-known bound avoids a device->host sync; clamp for safety.
if topk_budget_max is None:
k_max = int(k_eff.max().item()) if k_eff.numel() > 0 else 0
else:
k_max = int(topk_budget_max)
if k_max < 0:
k_max = 0
if k_max > L_max:
k_max = L_max
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
col_mask = torch.arange(k_max,
device=device).unsqueeze(0) < k_eff.unsqueeze(1)
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
# Stable, order-preserving local rank using segment-local prefix sums.
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1)
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
return keep_mask, local_rank, keep_len.to(torch.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment