Commit 863f93e6 authored by laibao's avatar laibao
Browse files

feat: kvpress flash_attn 实现非 chunked Top‑K compaction

parent a9ebf337
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
import numpy as np
import torch
......@@ -33,6 +33,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
make_local_attention_virtual_batches)
......@@ -43,6 +44,7 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
_DISABLE_SNAPKV_TRITON: bool = False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
......@@ -592,8 +594,10 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
cache_block_size = key_cache.shape[-3]
else:
key_cache, value_cache = kv_cache
cache_block_size = key_cache.shape[-2]
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
......@@ -751,6 +755,130 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=True,
)
# Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step.
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.kv_sharing_target_layer_name is None):
dst = None
if (attn_metadata.kv_compression_must_keep is not None
and attn_metadata.kv_compression_topk_budget
is not None):
forward_context = get_forward_context()
per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER
if per_layer_topk:
layer_name = getattr(layer, "layer_name", None)
if layer_name is None:
layer_name = str(id(layer))
dst_by_layer = getattr(
forward_context, "_kv_compression_compact_slots_by_layer",
None)
if dst_by_layer is None:
dst_by_layer = {}
setattr(
forward_context,
"_kv_compression_compact_slots_by_layer",
dst_by_layer,
)
dst = dst_by_layer.get(layer_name)
else:
dst = getattr(forward_context,
"_kv_compression_compact_slots", None)
if dst is None:
topk_budget = attn_metadata.kv_compression_topk_budget
token_scores: Optional[torch.Tensor] = None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max = int(
attn_metadata.kv_compression_topk_budget_max or 0)
if topk_budget_max > 0:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
w = torch.where(
topk_budget > 0,
torch.full_like(topk_budget, window),
torch.zeros_like(topk_budget),
)
token_scores = _snapkv_like_token_scores(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
query_start_loc=attn_metadata.query_start_loc,
window=w,
sm_scale=self.scale,
)
dst = _topk_kv_compact_slot_mapping(
token_scores=token_scores,
must_keep=attn_metadata.kv_compression_must_keep,
topk_budget=topk_budget,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
block_table=attn_metadata.block_table,
block_size=cache_block_size,
max_query_len=attn_metadata.max_query_len,
topk_budget_max=topk_budget_max,
)
if per_layer_topk:
dst_by_layer[layer_name] = dst
else:
setattr(forward_context,
"_kv_compression_compact_slots", dst)
if dst is not None:
src = attn_metadata.slot_mapping
rewrite_mask = (dst >= 0) & (dst != src)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite = torch.where(rewrite_mask, dst, -1)
def _writeback(dst_mapping: torch.Tensor) -> None:
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if (envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
and key.dtype == value.dtype
and key.dtype == torch.float16):
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
from vllm.attention.utils.fa_utils import (
reshape_and_cache_cuda)
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
dst_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
_writeback(dst_rewrite)
return output
assert not use_local_attn, (
......@@ -1265,3 +1393,225 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)
def _snapkv_like_token_scores(
*,
query: torch.Tensor, # [T, Hq, D]
key: torch.Tensor, # [T, Hkv, D]
query_start_loc: torch.Tensor, # [B+1]
window: Union[int, torch.Tensor],
sm_scale: float,
) -> torch.Tensor:
"""Compute token-shared SnapKV-like scores for a packed varlen batch.
Scores are computed as the attention mass from the last `window` query
tokens to the earlier keys within the same scheduled segment (per request),
summed across KV heads.
Prefers a Triton implementation when available; falls back to a (slower)
PyTorch reference implementation otherwise.
"""
global _DISABLE_SNAPKV_TRITON
device = query.device
T, Hq, D = query.shape
Hkv = key.shape[1]
if Hq % Hkv != 0:
raise ValueError("Query heads must be a multiple of KV heads.")
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if (HAS_TRITON and not _DISABLE_SNAPKV_TRITON and device.type == "cuda"
and (not current_platform.is_rocm()
or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
and query.stride(-1) == 1 and key.stride(-1) == 1):
try:
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
w = int(window) if isinstance(window, int) else window
scores_per_head = query_aware_key_scores(
q=query,
k=key,
cu_seqlens_q=query_start_loc,
cu_seqlens_k=query_start_loc,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(token_scores)
except Exception as e:
_DISABLE_SNAPKV_TRITON = True
logger.warning(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s", e)
group = Hq // Hkv
# Read boundaries on host (small tensor).
qsl = query_start_loc.tolist()
B = len(qsl) - 1
wsl = None
if not isinstance(window, int):
if int(window.numel()) != B:
raise ValueError("window must be a scalar int or have shape [B].")
wsl = window.to(device="cpu", dtype=torch.int64).tolist()
scores = torch.zeros((T, ), device=device, dtype=torch.float32)
for b in range(B):
s = int(qsl[b])
e = int(qsl[b + 1])
L = e - s
if L <= 0:
continue
win_b = int(window) if wsl is None else int(wsl[b])
if win_b <= 0:
continue
win = min(win_b, L)
k_eff_end = L - win
if k_eff_end <= 0:
continue
q_win = query[e - win:e] # [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2) # [win, Hkv, D]
k_eff = key[s:s + k_eff_end] # [k_eff_end, Hkv, D]
qh = q_win.permute(1, 0, 2).to(torch.float32) # [Hkv, win, D]
kh = k_eff.permute(1, 0, 2).to(torch.float32) # [Hkv, k_eff_end, D]
logits = torch.matmul(qh, kh.transpose(1, 2)) * sm_scale # [Hkv, win, K]
probs = torch.softmax(logits, dim=-1)
# Sum over (heads, window queries) -> per-key token score.
scores[s:s + k_eff_end] = probs.sum(dim=1).sum(dim=0)
# Aggregate across tensor-parallel ranks so every rank selects the same
# token indices.
from vllm.distributed.parallel_state import get_tp_group
return get_tp_group().all_reduce(scores)
def _topk_kv_compact_slot_mapping(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
query_start_loc: torch.Tensor, # [B+1]
seq_lens: torch.Tensor, # [B] int32
block_table: torch.Tensor, # [B, max_blocks]
block_size: int,
max_query_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0 or B == 0:
return dst_slots
# Per-request segment boundaries in the packed [T] layout.
# NOTE: `query_start_loc` is already sliced to [B+1] by the model runner.
starts = query_start_loc[:B].to(torch.long)
ends = query_start_loc[1:B + 1].to(torch.long)
lengths = ends - starts # [B]
if lengths.numel() == 0:
return dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max = int(max_query_len) if max_query_len is not None else int(
lengths.max().item())
if L_max <= 0:
return dst_slots
# Map each token to its (request, offset-within-request) coordinate.
token_idx = torch.arange(T, device=device, dtype=torch.long)
# For monotonic `ends` (cu_seqlens), this returns the request id for each
# token in the packed layout.
# Use right=True so that idx==ends[b] maps to the *next* request (b+1),
# i.e., request segments are [start, end) in the packed layout.
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T] in [0, L_b)
# Clamp the per-request top-k budget to the number of candidate tokens
# (excluding must_keep).
must_keep_counts = torch.zeros(B, device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# Prefer an upper bound from CPU (piecewise graph), to avoid sync.
if topk_budget_max is not None:
k_max = min(int(topk_budget_max), L_max)
else:
k_max = int(k_eff.max().item())
# Build a padded [B, L_max] score matrix for a single batched Top-K call.
# Must-keep and padding positions are set to -inf to avoid selection.
keep_mask = must_keep.clone()
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(dtype=torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
# Select only the first k_eff[b] entries for each request b.
col_mask = torch.arange(k_max, device=device).unsqueeze(
0) < k_eff.unsqueeze(1) # [B, k_max]
# Avoid host-side synchronization from dynamic indexing. Instead, mark
# selected tokens via a fixed-size scatter-add.
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B, k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
# Compute segment-local ranks (0..kept-1) for kept tokens, preserving token
# order within each request, without dynamic indexing (graph-friendly).
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1.to(torch.long))
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
dst_slots = torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
return dst_slots
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment