Commit da35d84f authored by laibao's avatar laibao
Browse files

feat(kvpress): FlashAttention 接入 KV 压缩 hooks

parent dbcb0376
......@@ -61,6 +61,10 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
import vllm.envs as envs
from vllm.v1.kv_compression.flash_attn_hooks import (
maybe_compute_prompt_end_payload_flash_attn,
maybe_compact_kv_cache_flash_attn,
)
logger = init_logger(__name__)
......@@ -248,6 +252,11 @@ class FlashAttentionMetadata:
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
# When input tensors are padded (e.g., for sequence parallelism / piecewise
# CUDA graphs), `num_actual_tokens` may include padding. KV compression
# helpers require the unpadded scheduled token count to match
# `query_start_loc[-1]`.
num_unpadded_tokens: int | None = None
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
......@@ -262,6 +271,17 @@ class FlashAttentionMetadata:
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# KV compression metadata for token-shared selection.
kv_compression_must_keep: torch.Tensor | None = None
kv_compression_topk_budget: torch.Tensor | None = None
# CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max: int | None = None
# Chunked prefill: prompt-end one-shot scoring/Top-K (scheme 3).
kv_compression_prompt_end: torch.Tensor | None = None # [B] bool
kv_compression_prompt_lens: torch.Tensor | None = None # [B] int32
kv_compression_prompt_topk_keep: torch.Tensor | None = None # [B] int32
kv_compression_prompt_topk_keep_max: int | None = None
# For GQA DCP
max_dcp_context_kv_len: int | None = None
dcp_context_kv_lens: torch.Tensor | None = None
......@@ -546,6 +566,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_unpadded_tokens=common_attn_metadata.num_unpadded_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
......@@ -560,6 +581,13 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
kv_compression_must_keep=common_attn_metadata.kv_compression_must_keep,
kv_compression_topk_budget=common_attn_metadata.kv_compression_topk_budget,
kv_compression_topk_budget_max=common_attn_metadata.kv_compression_topk_budget_max,
kv_compression_prompt_end=common_attn_metadata.kv_compression_prompt_end,
kv_compression_prompt_lens=common_attn_metadata.kv_compression_prompt_lens,
kv_compression_prompt_topk_keep=common_attn_metadata.kv_compression_prompt_topk_keep,
kv_compression_prompt_topk_keep_max=common_attn_metadata.kv_compression_prompt_topk_keep_max,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
causal=causal,
......@@ -729,6 +757,25 @@ class FlashAttentionImpl(AttentionImpl):
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
num_unpadded_tokens = (
attn_metadata.num_unpadded_tokens
if attn_metadata.num_unpadded_tokens is not None
else num_actual_tokens
)
cache_block_size = int(
key_cache.shape[2] if current_platform.is_rocm() else key_cache.shape[1]
)
if envs.VLLM_ENABLE_KV_COMPRESSION:
maybe_compute_prompt_end_payload_flash_attn(
kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
query=query,
num_actual_tokens=num_unpadded_tokens,
key_cache=key_cache,
cache_block_size=cache_block_size,
attn_metadata=attn_metadata,
sm_scale=self.scale,
)
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
......@@ -761,6 +808,26 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=k_descale,
v_descale=v_descale,
)
if envs.VLLM_ENABLE_KV_COMPRESSION:
maybe_compact_kv_cache_flash_attn(
kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
layer=layer,
query=query,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
num_actual_tokens=num_unpadded_tokens,
cache_block_size=cache_block_size,
attn_metadata=attn_metadata,
sm_scale=self.scale,
kv_cache_dtype=self.kv_cache_dtype,
reshape_and_cache=(
reshape_and_cache_cuda
if current_platform.is_rocm()
else reshape_and_cache_flash
),
)
return output
else:
sliding_window_size = (
......@@ -822,6 +889,26 @@ class FlashAttentionImpl(AttentionImpl):
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
if envs.VLLM_ENABLE_KV_COMPRESSION:
maybe_compact_kv_cache_flash_attn(
kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
layer=layer,
query=query,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
num_actual_tokens=num_unpadded_tokens,
cache_block_size=cache_block_size,
attn_metadata=attn_metadata,
sm_scale=self.scale,
kv_cache_dtype=self.kv_cache_dtype,
reshape_and_cache=(
reshape_and_cache_cuda
if current_platform.is_rocm()
else reshape_and_cache_flash
),
)
return output
# Cascade attention (rare case).
......@@ -851,6 +938,26 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale,
s_aux=self.sinks,
)
if envs.VLLM_ENABLE_KV_COMPRESSION:
maybe_compact_kv_cache_flash_attn(
kv_sharing_target_layer_name=self.kv_sharing_target_layer_name,
layer=layer,
query=query,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
num_actual_tokens=num_unpadded_tokens,
cache_block_size=cache_block_size,
attn_metadata=attn_metadata,
sm_scale=self.scale,
kv_cache_dtype=self.kv_cache_dtype,
reshape_and_cache=(
reshape_and_cache_cuda
if current_platform.is_rocm()
else reshape_and_cache_flash
),
)
return output
def do_kv_cache_update(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
import vllm.envs as envs
from vllm.v1.kv_compression.slot_mapping import topk_kv_compact_slot_mapping
from vllm.v1.kv_compression.snapkv_score import snapkv_like_token_scores
def snapkv_window_for_topk_budget(
    *,
    topk_budget: torch.Tensor,  # [B] int32
    window: int,
) -> torch.Tensor:
    """Return a per-request SnapKV window tensor for a mixed batch.

    Every request whose Top-K budget is strictly positive gets ``window``;
    every other request gets 0. A zero window is meant to let the scoring
    kernel skip requests that do not need token scores.

    Args:
        topk_budget: Per-request Top-K budget.
        window: Window size applied to requests with a positive budget.

    Returns:
        A tensor with the same shape, dtype and device as ``topk_budget``.
    """
    window_size = int(window)
    without_budget = ~(topk_budget > 0)
    # Start from "window everywhere", then zero the requests with no budget.
    return torch.full_like(topk_budget, window_size).masked_fill(
        without_budget, 0)
def compute_compact_dst_slots_for_step(
    *,
    query: torch.Tensor,  # [T, Hq, D] for this step
    key: torch.Tensor,  # [T, Hkv, D] for this step
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    topk_budget_max: int,
    max_query_len: int,
    sm_scale: float,
) -> torch.Tensor:
    """Compute the per-token KV compaction destinations for one step.

    Token scores are only computed when ``topk_budget_max`` is positive;
    otherwise ``None`` is forwarded as ``token_scores`` to
    ``topk_kv_compact_slot_mapping``.

    Returns:
        Whatever ``topk_kv_compact_slot_mapping`` returns for this step.
    """
    budget_cap = int(topk_budget_max)

    if budget_cap > 0:
        # Requests with a zero budget get a zero window (no scoring needed).
        per_request_window = snapkv_window_for_topk_budget(
            topk_budget=topk_budget,
            window=int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW),
        )
        scores = snapkv_like_token_scores(
            query=query,
            key=key,
            query_start_loc=query_start_loc,
            window=per_request_window,
            sm_scale=float(sm_scale),
        )
    else:
        scores = None

    return topk_kv_compact_slot_mapping(
        token_scores=scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_table=block_table,
        block_size=int(block_size),
        max_query_len=int(max_query_len),
        topk_budget_max=budget_cap,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional, Protocol
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_compression.compaction_step import compute_compact_dst_slots_for_step
from vllm.v1.kv_compression.forward_context import (
get_kv_compression_compact_slots,
get_kv_compression_prompt_payload,
set_kv_compression_compact_slots,
set_kv_compression_prompt_payload,
)
from vllm.v1.kv_compression.prompt_end import compute_prompt_end_indices
from vllm.v1.kv_compression.slot_mapping import kv_compaction_dst_rewrite_mapping
class _ReshapeAndCacheFn(Protocol):
    """Structural type of the cache-write callable passed to the hooks.

    Any callable accepting these eight arguments in this order matches.
    """

    def __call__(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Signature only; implementations are supplied by the caller."""
def maybe_compute_prompt_end_payload_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    query: torch.Tensor,
    num_actual_tokens: int,
    key_cache: torch.Tensor,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
) -> None:
    """Compute and stash prompt-end Top-K indices (chunked-prefill scheme 3).

    The function is a no-op when KV compression is disabled, when the layer
    shares another layer's KV cache, when any of the prompt-end metadata
    fields is missing, when the batch is empty, or when a payload is already
    stored in the current forward context.

    Otherwise the payload returned by ``compute_prompt_end_indices`` is stored
    in the forward context (if it is not ``None``). It is later consumed by
    the model runner to perform one-shot prompt KV compaction before the
    first decode step.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if kv_sharing_target_layer_name is not None:
        return

    # All three metadata fields are required; bail out if any is absent.
    required = tuple(
        getattr(attn_metadata, field_name, None) for field_name in (
            "kv_compression_prompt_end",
            "kv_compression_prompt_lens",
            "kv_compression_prompt_topk_keep",
        ))
    if any(field is None for field in required):
        return
    prompt_end, prompt_lens, topk_keep = required

    num_reqs = int(prompt_end.numel())
    if num_reqs <= 0:
        return

    ctx = get_forward_context()
    # Only compute once per forward context: skip if a payload already exists.
    if get_kv_compression_prompt_payload(ctx) is not None:
        return

    result = compute_prompt_end_indices(
        query=query[:num_actual_tokens],
        key_cache=key_cache,
        block_size=cache_block_size,
        query_start_loc=attn_metadata.query_start_loc[:num_reqs + 1],
        block_table=attn_metadata.block_table[:num_reqs],
        prompt_end=prompt_end,
        prompt_lens=prompt_lens,
        topk_keep=topk_keep,
        topk_keep_max=getattr(attn_metadata,
                              "kv_compression_prompt_topk_keep_max", None),
        sm_scale=sm_scale,
    )
    if result is None:
        return
    set_kv_compression_prompt_payload(ctx, result)
def maybe_compact_kv_cache_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    layer: Any,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_actual_tokens: int,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
    kv_cache_dtype: str,
    reshape_and_cache: _ReshapeAndCacheFn,
) -> None:
    """Optional per-step KV compaction for scheme 1/2 token-shared selection.

    The function is a no-op when KV compression is disabled, when the layer
    shares another layer's KV cache, when the ``kv_compression_must_keep`` /
    ``kv_compression_topk_budget`` metadata is missing, or when the batch is
    empty.

    Destination slots are looked up in the forward context first; when absent
    they are computed for this step and stored back. The destinations are
    then combined with the step's slot mapping and ``key`` / ``value`` are
    written to the cache with the resulting mapping.
    """
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return
    if kv_sharing_target_layer_name is not None:
        return

    must_keep = getattr(attn_metadata, "kv_compression_must_keep", None)
    topk_budget = getattr(attn_metadata, "kv_compression_topk_budget", None)
    if must_keep is None or topk_budget is None:
        return

    num_reqs = int(topk_budget.numel())
    if num_reqs <= 0:
        return

    ctx = get_forward_context()
    per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER

    dst_slots = get_kv_compression_compact_slots(
        ctx,
        per_layer_topk=per_layer_topk,
        layer=layer,
    )
    if dst_slots is None:
        # Nothing cached yet: compute the destinations and cache them.
        budget_cap = int(
            getattr(attn_metadata, "kv_compression_topk_budget_max", 0) or 0)
        dst_slots = compute_compact_dst_slots_for_step(
            query=query[:num_actual_tokens],
            key=key[:num_actual_tokens],
            query_start_loc=attn_metadata.query_start_loc[:num_reqs + 1],
            seq_lens=attn_metadata.seq_lens[:num_reqs],
            block_table=attn_metadata.block_table[:num_reqs],
            block_size=cache_block_size,
            must_keep=must_keep[:num_actual_tokens],
            topk_budget=topk_budget,
            topk_budget_max=budget_cap,
            max_query_len=attn_metadata.max_query_len,
            sm_scale=sm_scale,
        )
        set_kv_compression_compact_slots(
            ctx,
            per_layer_topk=per_layer_topk,
            layer=layer,
            dst=dst_slots,
        )
    if dst_slots is None:
        return

    src_slots = attn_metadata.slot_mapping[:num_actual_tokens]
    rewritten_slots = kv_compaction_dst_rewrite_mapping(dst_slots=dst_slots,
                                                        src_slots=src_slots)

    # Pick the cache writer once. The caller-supplied function is the default;
    # on ROCm the lightop kernel is preferred when it is enabled and both key
    # and value are float16.
    write_to_cache = reshape_and_cache
    if (current_platform.is_rocm() and envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
            and key.dtype == value.dtype and key.dtype == torch.float16):
        from lightop import reshape_and_cache_cuda
        write_to_cache = reshape_and_cache_cuda

    write_to_cache(
        key,
        value,
        key_cache,
        value_cache,
        rewritten_slots,
        kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
import vllm.envs as envs
from vllm.v1.kv_compression.kv_cache_view import paged_k_cache_view_for_triton_gather
from vllm.v1.kv_compression.snapkv_score import snapkv_query_aware_token_scores
from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
_topk_keep_mask_and_local_rank)
def _prompt_end_topk_keep_indices(
    *,
    token_scores: torch.Tensor,  # [T] float32
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32 (candidates only)
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the kept prompt indices (ascending) for one-shot compaction.

    Tokens inside the protected prefix/suffix (and, optionally, the last
    prompt token) are always marked must-keep; the remaining selection is
    delegated to ``_topk_keep_mask_and_local_rank``.

    Returns:
        idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
        keep_len: [B] int32, number of kept tokens per request
    """
    device = token_scores.device
    num_reqs = int(prompt_lens.numel())

    def _no_indices(rows: int) -> torch.Tensor:
        # Index matrix with zero kept columns.
        return torch.empty((rows, 0), device=device, dtype=torch.int32)

    if num_reqs == 0:
        return _no_indices(0), torch.empty((0, ),
                                           device=device,
                                           dtype=torch.int32)

    lens = prompt_lens.to(torch.long)
    # Exclusive prefix sum of the prompt lengths: [0, l0, l0+l1, ...].
    cu_lens = torch.zeros((num_reqs + 1, ), device=device, dtype=torch.long)
    cu_lens[1:] = torch.cumsum(lens, dim=0)

    total_tokens = int(token_scores.numel())
    if total_tokens == 0:
        return _no_indices(num_reqs), torch.zeros((num_reqs, ),
                                                  device=device,
                                                  dtype=torch.int32)

    starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
        cu_seqlens=cu_lens,
        total_tokens=total_tokens,
    )

    # Must-keep mask (protected prefix/suffix + optional last prompt token).
    nonneg_lens = torch.clamp(lens, min=0)
    prefix_len = nonneg_lens.clamp_max(max(protected_prefix, 0))
    suffix_len = nonneg_lens.clamp_max(max(protected_suffix, 0))
    suffix_start = (lens - suffix_len).clamp_min(0)

    in_prefix = pos_in_req < prefix_len.index_select(0, req_ids)
    in_suffix = pos_in_req >= suffix_start.index_select(0, req_ids)
    must_keep = in_prefix | in_suffix
    if keep_last_token:
        last_pos = (lens - 1).clamp_min(0).index_select(0, req_ids)
        must_keep |= pos_in_req == last_pos

    keep_mask, local_rank, keep_len = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_keep,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=int(lens.max().item()),
        topk_budget_max=topk_keep_max,
    )

    # num_reqs > 0 is guaranteed here by the early return above.
    widest = int(keep_len.max().item())
    if widest <= 0:
        return _no_indices(num_reqs), keep_len

    # Scatter each kept token's in-request position into its
    # (request, local_rank) cell of a zero-initialised [B, K_max] matrix.
    idx_sorted = torch.zeros((num_reqs, widest),
                             device=device,
                             dtype=torch.int32)
    flat_targets = (req_ids * widest + local_rank).masked_select(keep_mask)
    kept_positions = pos_in_req.to(torch.int32).masked_select(keep_mask)
    idx_sorted.view(-1).scatter_(0, flat_targets, kept_positions)
    return idx_sorted, keep_len
def compute_prompt_end_indices(
    *,
    query: torch.Tensor,  # [T, Hq, D] scheduled tokens for this step
    key_cache: torch.Tensor,  # layer KV cache view (platform-dependent)
    block_size: int,
    query_start_loc: torch.Tensor,  # [B+1] int32
    block_table: torch.Tensor,  # [B, max_blocks] int32
    prompt_end: torch.Tensor,  # [B] bool
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    Only requests flagged in ``prompt_end`` are processed. For each of them
    the trailing queries of this step (at most the configured SnapKV window)
    are packed, the prompt keys are gathered from the paged cache, token
    scores are computed, and the kept indices are selected.

    Returns:
        ``None`` when no request is flagged or no query tokens are available;
        otherwise a dict with ``req_indices``, ``idx_sorted``, ``keep_len``
        and ``prompt_lens`` for the selected requests.
    """
    device = query.device
    if prompt_end.numel() == 0:
        return None

    # Batch rows whose prompt ends in this step.
    selected = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(selected.numel()) == 0:
        return None

    max_window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Build packed Q window (last `window` queries per selected request).
    # The offsets are copied to the host so the slicing can use Python ints.
    selected_host = selected.to(device="cpu", dtype=torch.int64).tolist()
    q_offsets_host = query_start_loc.to(device="cpu",
                                        dtype=torch.int64).tolist()
    query_tails = []
    window_sizes = []
    cu_q = [0]
    for req in selected_host:
        begin = int(q_offsets_host[req])
        end = int(q_offsets_host[req + 1])
        span = min(max_window, max(0, end - begin))
        window_sizes.append(int(span))
        if span > 0:
            query_tails.append(query[end - span:end])
        cu_q.append(cu_q[-1] + int(span))
    if cu_q[-1] <= 0:
        return None

    q_packed = torch.cat(query_tails, dim=0) if query_tails else query[:0]
    cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
    windows = torch.tensor(window_sizes, device=device, dtype=torch.int32)

    # Gather full prompt keys for the selected requests into a packed
    # [T, Hk, D].
    prompt_lens_sel = prompt_lens.index_select(0, selected).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, selected).to(torch.int32)
    num_selected = int(prompt_lens_sel.numel())
    cu_seqlens_k = torch.zeros((num_selected + 1, ),
                               device=device,
                               dtype=torch.int32)
    if num_selected > 0:
        cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
    block_table_sel = block_table.index_select(0, selected).to(torch.int32)

    cache_view = paged_k_cache_view_for_triton_gather(
        key_cache=key_cache,
        block_size=int(block_size),
    )
    # Imported here (not at module level), as in the original code.
    from vllm.v1.kv_compression.kv_cache_triton import (
        gather_k_to_packed_triton)
    k_packed = gather_k_to_packed_triton(
        cache_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    scores = snapkv_query_aware_token_scores(
        query=q_packed,
        key=k_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        window=windows,
        sm_scale=float(sm_scale),
    )
    idx_sorted, keep_len = _prompt_end_topk_keep_indices(
        token_scores=scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )
    return {
        "req_indices": selected.to(torch.int32),
        "idx_sorted": idx_sorted,  # [B_sel, K_max] int32
        "keep_len": keep_len,  # [B_sel] int32
        "prompt_lens": prompt_lens_sel,  # [B_sel] int32
    }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment