Commit 0b5595ff authored by laibao's avatar laibao
Browse files

v1: chunked prefill 下延后 KV compression 到 prompt 结束

chunked prefill 时如果在 prefill 中途做 prompt KV 压缩,会导致下一段 prefill 只能看到被截断的历史(语义变化/质量塌陷)。
改为“scheme 3”:仅在每个请求的最后一个 prefill chunk 计算一次全局 prompt Top-K 索引并缓存;在第一次 decode 前执行一次性 in-place compaction。
补充 prompt keep/budget 计算辅助函数,并新增 Triton KV cache gather/compaction 实现。
parent 2676ad00
...@@ -190,6 +190,11 @@ class FlashAttentionMetadata: ...@@ -190,6 +190,11 @@ class FlashAttentionMetadata:
kv_compression_topk_budget: Optional[torch.Tensor] = None kv_compression_topk_budget: Optional[torch.Tensor] = None
# CPU-known max Top-K budget for this step (avoids device->host sync). # CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max: Optional[int] = None kv_compression_topk_budget_max: Optional[int] = None
# Chunked prefill: prompt-end one-shot scoring/Top-K (scheme 3).
kv_compression_prompt_end: Optional[torch.Tensor] = None # [B] bool
kv_compression_prompt_lens: Optional[torch.Tensor] = None # [B] int32
kv_compression_prompt_topk_keep: Optional[torch.Tensor] = None # [B] int32
kv_compression_prompt_topk_keep_max: Optional[int] = None
# Optional aot scheduling # Optional aot scheduling
scheduler_metadata: Optional[torch.Tensor] = None scheduler_metadata: Optional[torch.Tensor] = None
...@@ -300,6 +305,10 @@ class FlashAttentionMetadataBuilder( ...@@ -300,6 +305,10 @@ class FlashAttentionMetadataBuilder(
kv_compression_must_keep = None kv_compression_must_keep = None
kv_compression_topk_budget = None kv_compression_topk_budget = None
kv_compression_topk_budget_max: Optional[int] = None kv_compression_topk_budget_max: Optional[int] = None
kv_compression_prompt_end = None
kv_compression_prompt_lens = None
kv_compression_prompt_topk_keep = None
kv_compression_prompt_topk_keep_max: Optional[int] = None
if (envs.VLLM_ENABLE_KV_COMPRESSION if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.kv_compression_needs_compaction): and self.runner.kv_compression_needs_compaction):
kv_compression_must_keep = self.runner.kv_compression_must_keep[: kv_compression_must_keep = self.runner.kv_compression_must_keep[:
...@@ -312,6 +321,17 @@ class FlashAttentionMetadataBuilder( ...@@ -312,6 +321,17 @@ class FlashAttentionMetadataBuilder(
self.runner.kv_compression_topk_budget_np[:num_reqs].max()) self.runner.kv_compression_topk_budget_np[:num_reqs].max())
else: else:
kv_compression_topk_budget_max = 0 kv_compression_topk_budget_max = 0
elif (envs.VLLM_ENABLE_KV_COMPRESSION
and self.runner.scheduler_config.chunked_prefill_enabled):
# Scheme 3: compute global prompt indices only on the last prefill
# chunk (per request), and perform the actual cache compaction
# before the first decode step.
if num_reqs > 0 and self.runner.kv_compression_prompt_end_np[:num_reqs].any():
kv_compression_prompt_end = self.runner.kv_compression_prompt_end[:num_reqs]
kv_compression_prompt_lens = self.runner.kv_compression_prompt_lens[:num_reqs]
kv_compression_prompt_topk_keep = self.runner.kv_compression_prompt_topk_keep[:num_reqs]
kv_compression_prompt_topk_keep_max = int(
self.runner.kv_compression_prompt_topk_keep_max or 0)
if self.aot_sliding_window is None: if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1) self.aot_sliding_window = (-1, -1)
...@@ -474,6 +494,10 @@ class FlashAttentionMetadataBuilder( ...@@ -474,6 +494,10 @@ class FlashAttentionMetadataBuilder(
kv_compression_must_keep=kv_compression_must_keep, kv_compression_must_keep=kv_compression_must_keep,
kv_compression_topk_budget=kv_compression_topk_budget, kv_compression_topk_budget=kv_compression_topk_budget,
kv_compression_topk_budget_max=kv_compression_topk_budget_max, kv_compression_topk_budget_max=kv_compression_topk_budget_max,
kv_compression_prompt_end=kv_compression_prompt_end,
kv_compression_prompt_lens=kv_compression_prompt_lens,
kv_compression_prompt_topk_keep=kv_compression_prompt_topk_keep,
kv_compression_prompt_topk_keep_max=kv_compression_prompt_topk_keep_max,
local_attn_metadata=local_attn_metadata, local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits, max_num_splits=max_num_splits,
...@@ -744,6 +768,32 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -744,6 +768,32 @@ class FlashAttentionImpl(AttentionImpl):
# a packed layout for the next step. # a packed layout for the next step.
if (envs.VLLM_ENABLE_KV_COMPRESSION if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.kv_sharing_target_layer_name is None): and self.kv_sharing_target_layer_name is None):
# Scheme 3 (chunked prefill): on the last prompt chunk, compute
# global prompt indices (score/topk) and cache them in the
# forward context for the model runner to consume before the
# first decode step. Do not write back here.
if (attn_metadata.kv_compression_prompt_end is not None
and attn_metadata.kv_compression_prompt_lens is not None
and attn_metadata.kv_compression_prompt_topk_keep is not None):
forward_context = get_forward_context()
payload = getattr(forward_context,
"_kv_compression_prompt_payload", None)
if payload is None:
payload = _compute_prompt_end_indices(
query=query[:num_actual_tokens],
key_cache=key_cache,
query_start_loc=attn_metadata.query_start_loc,
block_table=attn_metadata.block_table,
prompt_end=attn_metadata.kv_compression_prompt_end,
prompt_lens=attn_metadata.kv_compression_prompt_lens,
topk_keep=attn_metadata.kv_compression_prompt_topk_keep,
topk_keep_max=attn_metadata.kv_compression_prompt_topk_keep_max,
sm_scale=self.scale,
)
if payload is not None:
setattr(forward_context,
"_kv_compression_prompt_payload", payload)
dst = None dst = None
if (attn_metadata.kv_compression_must_keep is not None if (attn_metadata.kv_compression_must_keep is not None
and attn_metadata.kv_compression_topk_budget and attn_metadata.kv_compression_topk_budget
...@@ -1051,6 +1101,252 @@ def _snapkv_like_token_scores( ...@@ -1051,6 +1101,252 @@ def _snapkv_like_token_scores(
return get_tp_group().all_reduce(scores) return get_tp_group().all_reduce(scores)
def _prompt_end_topk_keep_indices(
*,
token_scores: torch.Tensor, # [T] float32
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32 (candidates only)
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select kept prompt indices (ascending) for one-shot compaction.
Returns:
idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
keep_len: [B] int32, number of kept tokens per request
"""
device = token_scores.device
B = int(prompt_lens.numel())
if B == 0:
empty = torch.empty((0, 0), device=device, dtype=torch.int32)
return empty, torch.empty((0, ), device=device, dtype=torch.int32)
prompt_lens_i64 = prompt_lens.to(torch.long)
cu = torch.zeros((B + 1, ), device=device, dtype=torch.long)
cu[1:] = torch.cumsum(prompt_lens_i64, dim=0)
starts = cu[:B]
ends = cu[1:]
T = int(token_scores.numel())
if T == 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, torch.zeros((B, ), device=device, dtype=torch.int32)
token_idx = torch.arange(T, device=device, dtype=torch.long)
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T]
# Must-keep mask (protected prefix/suffix + optional last prompt token).
prefix_len = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_prefix, 0))
suffix = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_suffix, 0))
suffix_start = (prompt_lens_i64 - suffix).clamp_min(0)
prefix_len_t = prefix_len.index_select(0, req_ids)
suffix_start_t = suffix_start.index_select(0, req_ids)
must_keep = (pos_in_req < prefix_len_t) | (pos_in_req >= suffix_start_t)
if keep_last_token:
last = (prompt_lens_i64 - 1).clamp_min(0)
last_t = last.index_select(0, req_ids)
must_keep |= pos_in_req == last_t
cand_counts = torch.zeros((B, ), device=device, dtype=torch.long)
cand_counts.scatter_add_(0, req_ids, (~must_keep).to(torch.long))
k_eff = torch.minimum(topk_keep.to(torch.long).clamp_min(0), cand_counts)
# CPU-known bound avoids a device->host sync; clamp for safety.
if topk_keep_max is None:
k_max = int(k_eff.max().item())
else:
k_max = int(topk_keep_max)
if k_max < 0:
k_max = 0
keep_mask = must_keep.clone()
if k_max > 0:
L_max = int(prompt_lens_i64.max().item())
masked_scores = token_scores.to(torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
col_mask = (torch.arange(k_max, device=device).unsqueeze(0) <
k_eff.unsqueeze(1))
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
keep_max_len = int(keep_len.max().item()) if B > 0 else 0
if keep_max_len <= 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, keep_len.to(torch.int32)
# Stable, order-preserving index list using segment-local ranks.
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1)
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all))
prefix_before_t = prefix_before.index_select(0, req_ids)
local_rank = keep_prefix - prefix_before_t - 1 # [T]
idx_sorted = torch.zeros((B, keep_max_len), device=device, dtype=torch.int32)
lin_out = (req_ids * keep_max_len + local_rank).masked_select(keep_mask)
vals = pos_in_req.to(torch.int32).masked_select(keep_mask)
idx_sorted.view(-1).scatter_(0, lin_out, vals)
return idx_sorted, keep_len.to(torch.int32)
def _compute_prompt_end_indices(
*,
query: torch.Tensor, # [T, Hq, D] scheduled tokens for this step
key_cache: torch.Tensor, # layer KV cache view (platform-dependent)
query_start_loc: torch.Tensor, # [B+1] int32
block_table: torch.Tensor, # [B, max_blocks] int32
prompt_end: torch.Tensor, # [B] bool
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32
topk_keep_max: Optional[int],
sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
"""Compute one-shot prompt compaction indices on the last prefill chunk."""
device = query.device
if prompt_end.numel() == 0:
return None
sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
if int(sel.numel()) == 0:
return None
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)
# Build packed Q window (last `window` queries per selected request).
sel_list = sel.to(device="cpu", dtype=torch.int64).tolist()
qsl = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
q_chunks = []
cu_q = [0]
w_list = []
for b in sel_list:
s = int(qsl[b])
e = int(qsl[b + 1])
q_len = max(0, e - s)
win = min(window, q_len)
w_list.append(int(win))
if win > 0:
q_chunks.append(query[e - win:e])
cu_q.append(cu_q[-1] + int(win))
if cu_q[-1] <= 0:
return None
q_packed = torch.cat(q_chunks, dim=0) if q_chunks else query[:0]
cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
w = torch.tensor(w_list, device=device, dtype=torch.int32)
# Gather full prompt keys for the selected requests into a packed [T, Hk, D].
prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
cu_seqlens_k = torch.zeros((int(prompt_lens_sel.numel()) + 1, ),
device=device,
dtype=torch.int32)
if int(prompt_lens_sel.numel()) > 0:
cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
block_table_sel = block_table.index_select(0, sel).to(torch.int32)
if not current_platform.is_rocm():
# CUDA cache view: [num_blocks, block_size, H, D] -> [num_blocks, H, block_size, D]
key_cache_view = key_cache.permute(0, 2, 1, 3)
else:
key_cache_view = key_cache
from vllm.v1.attention.kv_compression.kv_cache_triton import (
gather_k_to_packed_triton)
k_packed = gather_k_to_packed_triton(
key_cache_view,
block_table_sel,
prompt_lens_sel,
cu_seqlens_k,
)
# SnapKV Triton scoring (token-shared via sum over KV heads).
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
try:
scores_per_head = query_aware_key_scores(
q=q_packed,
k=k_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
except Exception:
# Fallback: PyTorch reference scoring (slow but correctness-oriented).
Hq = q_packed.shape[1]
Hk = k_packed.shape[1]
D = q_packed.shape[2]
if Hq % Hk != 0:
raise
group = Hq // Hk
token_scores = torch.zeros((k_packed.shape[0], ),
device=device,
dtype=torch.float32)
for i in range(len(sel_list)):
qs = int(cu_q[i])
qe = int(cu_q[i + 1])
ks = int(cu_seqlens_k[i].item())
ke = int(cu_seqlens_k[i + 1].item())
if qe <= qs or ke <= ks:
continue
q_win = q_packed[qs:qe] # [win, Hq, D]
q_win = q_win.reshape(q_win.shape[0], Hk, group, D).mean(dim=2)
k_all = k_packed[ks:ke]
qh = q_win.permute(1, 0, 2).to(torch.float32)
kh = k_all.permute(1, 0, 2).to(torch.float32)
logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
probs = torch.softmax(logits, dim=-1)
token_scores[ks:ke] = probs.sum(dim=1).sum(dim=0)
from vllm.distributed.parallel_state import get_tp_group
token_scores = get_tp_group().all_reduce(token_scores)
idx_sorted, keep_len = _prompt_end_topk_keep_indices(
token_scores=token_scores,
prompt_lens=prompt_lens_sel,
topk_keep=topk_keep_sel,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
topk_keep_max=topk_keep_max,
)
return {
"req_indices": sel.to(torch.int32),
"idx_sorted": idx_sorted, # [B_sel, K_max] int32
"keep_len": keep_len, # [B_sel] int32
"prompt_lens": prompt_lens_sel, # [B_sel] int32
}
def _topk_kv_compact_slot_mapping( def _topk_kv_compact_slot_mapping(
*, *,
token_scores: Optional[torch.Tensor], # [T] float32 token_scores: Optional[torch.Tensor], # [T] float32
......
from __future__ import annotations
from typing import Optional, Tuple
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
def _require_triton() -> None:
    """Raise ``RuntimeError`` unless the ``HAS_TRITON`` flag is set."""
    if HAS_TRITON:
        return
    raise RuntimeError("Triton is not available.")
def _check_cuda(*tensors: torch.Tensor) -> None:
for t in tensors:
if not isinstance(t, torch.Tensor):
raise TypeError("Expected torch.Tensor inputs.")
if t.device.type != "cuda":
raise RuntimeError("Triton KV cache ops require CUDA/ROCm tensors.")
# Gather kernel: copies the keys of each request out of the paged key cache
# into a packed [T, H, D] output. Program axes: 0 -> (request, head) pair,
# 1 -> token tile of BLOCK_T tokens, 2 -> feature tile of BLOCK_D channels.
# Tile sizes are chosen by the autotuner, keyed on D.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
    ],
    key=["D"],
)
@triton.jit
def _gather_k_to_packed_kernel(
    K_ptr,  # key cache base pointer
    out_ptr,  # packed output base pointer
    blk_ids_ptr,  # flattened block table (physical block ids)
    req_blk_starts_ptr,  # per-request offset into blk_ids_ptr
    cu_seqlens_ptr,  # per-request start row in the packed output
    seq_lens_ptr,  # per-request number of tokens to gather
    B,
    H,
    max_blocks,  # not read inside the kernel body
    block_size,
    D,
    sKb,  # key cache strides: block, head, in-block token, feature
    sKh,
    sKt,
    sKd,
    so_t,  # output strides: token, head, feature
    so_h,
    so_d,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_t = tl.program_id(1)
    pid_d = tl.program_id(2)

    # Axis 0 enumerates (request, head) pairs in row-major order.
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return

    seq_len = tl.load(seq_lens_ptr + b)
    if seq_len <= 0:
        return

    # Token tile; lanes past this request's length are masked out.
    t0 = pid_t * BLOCK_T
    t_range = t0 + tl.arange(0, BLOCK_T)
    t_mask = t_range < seq_len

    # Feature tile; lanes past D are masked out.
    d0 = pid_d * BLOCK_D
    d_range = d0 + tl.arange(0, BLOCK_D)
    d_mask = d_range < D

    # Map logical token indices -> physical block ids.
    blk = t_range // block_size
    inb = t_range - blk * block_size
    req_blk_start = tl.load(req_blk_starts_ptr + b)
    gblk = req_blk_start + blk
    # Guard against out-of-range block indices (should not happen when block_table
    # covers the sequence length).
    gblk_safe = tl.where(t_mask, gblk, 0)
    bid = tl.load(blk_ids_ptr + gblk_safe, mask=t_mask, other=0)

    # Source: key cache layout [num_blocks, H, block_size, D]
    src_base = K_ptr + bid[:, None] * sKb + h * sKh + inb[:, None] * sKt
    src_ptrs = src_base + d_range[None, :] * sKd

    # Destination: packed output layout [T, H, D]
    out_start = tl.load(cu_seqlens_ptr + b)
    dst_base = out_ptr + (out_start + t_range)[:, None] * so_t + h * so_h
    dst_ptrs = dst_base + d_range[None, :] * so_d

    # Masked copy of one [BLOCK_T, BLOCK_D] tile.
    tile = tl.load(src_ptrs, mask=(t_mask[:, None] & d_mask[None, :]), other=0)
    tl.store(dst_ptrs, tile, mask=(t_mask[:, None] & d_mask[None, :]))
@torch.inference_mode()
def gather_k_to_packed_triton(
    key_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    cu_seqlens: torch.Tensor,
    *,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gather a block-wise KV key cache into a packed [T, H, D] tensor.

    Expected layouts:
    - key_cache:   [num_blocks, H, block_size, D]
    - block_table: [B, max_blocks] int32 physical block ids
    - seq_lens:    [B] int32 logical lengths (tokens) to gather
    - cu_seqlens:  [B+1] int32 cumulative offsets into the packed output

    Args:
        out: Optional preallocated output of shape ``(cu_seqlens[-1], H, D)``.
            A new tensor with ``key_cache``'s dtype is allocated when omitted.

    Returns:
        The packed keys (``out`` itself when it was supplied).

    Raises:
        RuntimeError: Triton is unavailable or an input is not on a CUDA
            device.
        TypeError: An input is not a tensor.
        ValueError: An input has the wrong rank, ``out`` has the wrong shape,
            or ``block_table`` / ``cu_seqlens`` do not cover the batch and
            the longest sequence.
    """
    _require_triton()
    _check_cuda(key_cache, block_table, seq_lens, cu_seqlens)
    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor [num_blocks, H, Tb, D].")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if seq_lens.ndim != 1:
        raise ValueError("seq_lens must be 1D [B].")
    if cu_seqlens.ndim != 1:
        raise ValueError("cu_seqlens must be 1D [B+1].")

    device = key_cache.device
    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    D = int(key_cache.shape[3])
    B = int(seq_lens.numel())
    if B == 0:
        return torch.empty((0, H, D), device=device, dtype=key_cache.dtype)

    # The kernel indexes block_table rows and cu_seqlens entries by request
    # id, so both must cover the whole batch.
    if int(block_table.shape[0]) < B:
        raise ValueError(
            f"block_table has {int(block_table.shape[0])} rows, need at least {B}."
        )
    if int(cu_seqlens.numel()) < B + 1:
        raise ValueError(
            f"cu_seqlens has {int(cu_seqlens.numel())} entries, need at least {B + 1}."
        )
    max_blocks = int(block_table.shape[1])

    seq_lens_i32 = seq_lens.to(device=device, dtype=torch.int32)
    cu_i32 = cu_seqlens.to(device=device, dtype=torch.int32)
    total_tokens = int(cu_i32[-1].item()) if cu_i32.numel() > 0 else 0

    if out is None:
        out = torch.empty((total_tokens, H, D),
                          device=device,
                          dtype=key_cache.dtype)
    else:
        if out.shape != (total_tokens, H, D):
            raise ValueError(
                f"out has shape {tuple(out.shape)}, expected {(total_tokens, H, D)}."
            )

    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    req_starts = (torch.arange(B, device=device, dtype=torch.int32) *
                  max_blocks)

    sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()]
    so_t, so_h, so_d = [int(s) for s in out.stride()]

    L_max = int(seq_lens_i32.max().item()) if B > 0 else 0
    if total_tokens == 0 or L_max == 0 or D == 0 or H == 0:
        return out
    # A sequence longer than its block-table row would make the kernel read
    # block ids belonging to another request (or past the table).
    if L_max > max_blocks * block_size:
        raise ValueError(
            f"block_table covers {max_blocks * block_size} tokens per request, "
            f"but the longest sequence has {L_max}.")

    # Size the launch from the tile sizes of the config being run, so larger
    # autotune configs do not launch extra, fully-masked programs.
    def grid(meta):
        return (
            B * H,
            triton.cdiv(L_max, meta["BLOCK_T"]),
            triton.cdiv(D, meta["BLOCK_D"]),
        )

    _gather_k_to_packed_kernel[grid](
        key_cache,
        out,
        blk_ids,
        req_starts,
        cu_i32,
        seq_lens_i32,
        B,
        H,
        max_blocks,
        block_size,
        D,
        sKb,
        sKh,
        sKt,
        sKd,
        so_t,
        so_h,
        so_d,
    )
    return out
# In-place front compaction of the key cache: for request b and head h, the
# key at logical position idx[b, h, k] is moved to logical position k for
# every k < keep[b]. Program axes: 0 -> (request, head) pair, 1 -> feature
# tile of BLOCK_D channels; each program walks the kept positions in chunks
# of BLOCK_T.
#
# Autotuning notes:
# * The kernel overwrites its own input and is not idempotent, while the
#   autotuner launches it repeatedly to benchmark configs. `restore_value`
#   makes the autotuner put the cache contents back after each benchmark run
#   so only the final launch takes effect.
# * K_max is deliberately not part of the key: it varies from batch to batch,
#   and every new key value triggers a full re-benchmark (including the
#   save/restore of the cache).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
    ],
    key=['Dk'],
    restore_value=['K_ptr'],
)
@triton.jit
def _front_compact_inplace_fa_k_kernel(
    K_ptr,  # key cache base pointer (read and written)
    blk_ids_ptr,  # flattened block table (physical block ids)
    req_blk_starts_ptr,  # per-request offset into blk_ids_ptr
    idx_ptr,  # kept source positions, indexed [b, h, k] via si_*
    keep_ptr,  # per-request number of kept positions
    B,
    H,
    K_max,
    block_size,
    Dk,
    sKb,  # key cache strides: block, head, in-block token, feature
    sKh,
    sKt,
    sKd,
    si_b,  # idx strides: request, head, kept slot
    si_h,
    si_k,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Axis 0 enumerates (request, head) pairs in row-major order.
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return

    # Feature tile; masked lanes are redirected to channel 0 and never
    # loaded or stored.
    d0 = pid_d * BLOCK_D
    d_range = d0 + tl.arange(0, BLOCK_D)
    d_mask = d_range < Dk
    d_safe = tl.where(d_mask, d_range, 0)

    keep_b = tl.load(keep_ptr + b)
    if keep_b <= 0:
        return

    req_blk_start = tl.load(req_blk_starts_ptr + b)

    # NOTE(review): a chunk's destination rows can be another lane's source
    # rows (in this chunk or an earlier one), and there is no barrier between
    # the tile load and the tile store. Confirm that ordering across warps is
    # guaranteed, or that the index pattern rules the overlap out.
    k0 = 0
    while k0 < keep_b:
        k_range = k0 + tl.arange(0, BLOCK_T)
        k_mask = (k_range < K_max) & (k_range < keep_b)
        k_safe = tl.where(k_mask, k_range, 0)

        idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
        t_src = tl.load(idx_base, mask=k_mask, other=0)

        # No-op copies (src == dst) can be skipped safely because idx_sorted is
        # ascending, so we always copy from later/equal positions to earlier.
        t_dst = k_safe
        copy_mask = k_mask & (t_src != t_dst)

        # Logical source position -> (physical block id, in-block offset).
        blk_src = t_src // block_size
        inb_src = t_src % block_size
        gblk_src = req_blk_start + blk_src
        bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask, other=0)

        # Logical destination position -> (physical block id, in-block offset).
        blk_dst = t_dst // block_size
        inb_dst = t_dst % block_size
        gblk_dst = req_blk_start + blk_dst
        bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask, other=0)

        src_base = K_ptr + bid_src[:, None] * sKb + h * sKh + inb_src[:, None] * sKt
        src_ptrs = src_base + d_safe[None, :] * sKd
        dst_base = K_ptr + bid_dst[:, None] * sKb + h * sKh + inb_dst[:, None] * sKt
        dst_ptrs = dst_base + d_safe[None, :] * sKd

        tile = tl.load(src_ptrs,
                       mask=(copy_mask[:, None] & d_mask[None, :]),
                       other=0)
        tl.store(dst_ptrs, tile, mask=(copy_mask[:, None] & d_mask[None, :]))

        k0 += BLOCK_T
# In-place front compaction of the value cache: for request b and head h, the
# value at logical position idx[b, h, k] is moved to logical position k for
# every k < keep[b]. Program axes: 0 -> (request, head) pair, 1 -> feature
# tile of BLOCK_D channels; each program walks the kept positions in chunks
# of BLOCK_T.
#
# Autotuning notes:
# * The kernel overwrites its own input and is not idempotent, while the
#   autotuner launches it repeatedly to benchmark configs. `restore_value`
#   makes the autotuner put the cache contents back after each benchmark run
#   so only the final launch takes effect.
# * K_max is deliberately not part of the key: it varies from batch to batch,
#   and every new key value triggers a full re-benchmark (including the
#   save/restore of the cache).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
    ],
    key=['Dv'],
    restore_value=['V_ptr'],
)
@triton.jit
def _front_compact_inplace_fa_v_kernel(
    V_ptr,  # value cache base pointer (read and written)
    blk_ids_ptr,  # flattened block table (physical block ids)
    req_blk_starts_ptr,  # per-request offset into blk_ids_ptr
    idx_ptr,  # kept source positions, indexed [b, h, k] via si_*
    keep_ptr,  # per-request number of kept positions
    B,
    H,
    K_max,
    block_size,
    Dv,
    sv_b,  # value cache strides: block, head, feature, in-block token
    sv_h,
    sv_d,
    sv_t,
    si_b,  # idx strides: request, head, kept slot
    si_h,
    si_k,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Axis 0 enumerates (request, head) pairs in row-major order.
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return

    # Feature tile; masked lanes are redirected to channel 0 and never
    # loaded or stored.
    d0 = pid_d * BLOCK_D
    d_range = d0 + tl.arange(0, BLOCK_D)
    d_mask = d_range < Dv
    d_safe = tl.where(d_mask, d_range, 0)

    keep_b = tl.load(keep_ptr + b)
    if keep_b <= 0:
        return

    req_blk_start = tl.load(req_blk_starts_ptr + b)

    # NOTE(review): a chunk's destination rows can be another lane's source
    # rows (in this chunk or an earlier one), and there is no barrier between
    # the tile load and the tile store. Confirm that ordering across warps is
    # guaranteed, or that the index pattern rules the overlap out.
    k0 = 0
    while k0 < keep_b:
        k_range = k0 + tl.arange(0, BLOCK_T)
        k_mask = (k_range < K_max) & (k_range < keep_b)
        k_safe = tl.where(k_mask, k_range, 0)

        idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
        t_src = tl.load(idx_base, mask=k_mask, other=0)

        # Positions that already sit at their destination need no copy.
        t_dst = k_safe
        copy_mask = k_mask & (t_src != t_dst)

        # Logical source position -> (physical block id, in-block offset).
        blk_src = t_src // block_size
        inb_src = t_src % block_size
        gblk_src = req_blk_start + blk_src
        bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask, other=0)

        # Logical destination position -> (physical block id, in-block offset).
        blk_dst = t_dst // block_size
        inb_dst = t_dst % block_size
        gblk_dst = req_blk_start + blk_dst
        bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask, other=0)

        # value layout: [num_blocks, H, Dv, block_size]
        v_src_base = V_ptr + bid_src[:, None] * sv_b + h * sv_h + d_safe[None, :] * sv_d
        v_src_ptrs = v_src_base + inb_src[:, None] * sv_t
        v_dst_base = V_ptr + bid_dst[:, None] * sv_b + h * sv_h + d_safe[None, :] * sv_d
        v_dst_ptrs = v_dst_base + inb_dst[:, None] * sv_t

        tile = tl.load(v_src_ptrs,
                       mask=(copy_mask[:, None] & d_mask[None, :]),
                       other=0)
        tl.store(v_dst_ptrs, tile, mask=(copy_mask[:, None] & d_mask[None, :]))

        k0 += BLOCK_T
@torch.inference_mode()
def front_compact_inplace_fa_triton(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    idx_sorted: torch.Tensor,
    keep: torch.Tensor,
) -> None:
    """In-place front compaction for FlashAttention KV cache.

    Moves selected time indices to the front [0..keep[b]) per request for both
    key_cache and value_cache in-place.

    Expected layouts:
    - key_cache:   [num_blocks, H, block_size, Dk]
    - value_cache: [num_blocks, H, Dv, block_size]
    - block_table: [B, max_blocks] int32 physical block ids
    - idx_sorted:  [B, K] int32 or [B, H, K] int32 (ascending indices)
    - keep:        [B] int32 (<= K), number of kept tokens per request

    Raises:
        RuntimeError: Triton is unavailable or an input is not on a CUDA
            device.
        TypeError: An input is not a tensor.
        ValueError: An input has the wrong rank, or ``idx_sorted`` / ``keep``
            do not cover every request (and head) of ``block_table``.
    """
    _require_triton()
    _check_cuda(key_cache, value_cache, block_table, idx_sorted, keep)
    if key_cache.ndim != 4 or value_cache.ndim != 4:
        raise ValueError("key_cache/value_cache must be 4D tensors.")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if idx_sorted.ndim not in (2, 3):
        raise ValueError("idx_sorted must be 2D [B,K] or 3D [B,H,K].")
    if keep.ndim != 1:
        raise ValueError("keep must be 1D [B].")

    device = key_cache.device
    B = int(block_table.shape[0])
    if B == 0:
        return

    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    Dk = int(key_cache.shape[3])
    Dv = int(value_cache.shape[2])

    # Cast before broadcasting over heads: a dtype conversion of the expanded
    # view would materialise H copies, whereas the kernels only need the
    # stride-0 head axis.
    idx_i32 = idx_sorted.to(device=device, dtype=torch.int32)
    if idx_i32.ndim == 2:
        idx_i32 = idx_i32[:, None, :].expand(-1, H, -1)
    K_max = int(idx_i32.shape[2])
    if K_max == 0:
        return

    # The kernels index idx/keep by request (and head) id taken from
    # block_table / key_cache, so both must cover those ranges.
    if int(idx_i32.shape[0]) < B or int(keep.numel()) < B:
        raise ValueError(
            f"idx_sorted ({int(idx_i32.shape[0])} rows) and keep "
            f"({int(keep.numel())} entries) must cover all {B} requests.")
    if int(idx_i32.shape[1]) < H:
        raise ValueError(
            f"idx_sorted has {int(idx_i32.shape[1])} heads, expected {H}.")

    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    max_blocks = int(block_table.shape[1])
    req_starts = (torch.arange(B, device=device, dtype=torch.int32) *
                  max_blocks)
    keep_i32 = keep.to(device=device, dtype=torch.int32)

    sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()]
    sv_b, sv_h, sv_d, sv_t = [int(s) for s in value_cache.stride()]
    si_b, si_h, si_k = [int(s) for s in idx_i32.stride()]

    if Dk > 0:
        # Size the launch from the feature tile of the config being run.
        def grid_k(meta):
            return (B * H, triton.cdiv(Dk, meta["BLOCK_D"]))

        _front_compact_inplace_fa_k_kernel[grid_k](
            key_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dk,
            sKb,
            sKh,
            sKt,
            sKd,
            si_b,
            si_h,
            si_k,
        )

    if Dv > 0:

        def grid_v(meta):
            return (B * H, triton.cdiv(Dv, meta["BLOCK_D"]))

        _front_compact_inplace_fa_v_kernel[grid_v](
            value_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dv,
            sv_b,
            sv_h,
            sv_d,
            sv_t,
            si_b,
            si_h,
            si_k,
        )
def make_fa_cache_view(
    *,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    is_rocm: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (K_view, V_view) in the canonical FA compaction layout.

    - K_view: [num_blocks, H, block_size, D]
    - V_view: [num_blocks, H, D, block_size]

    Args:
        key_cache: 4D key cache in either the ROCm layout ``[B, H, T, D]`` or
            the CUDA layout ``[B, T, H, D]``.
        value_cache: 4D value cache in either the ROCm layout
            ``[B, H, D, T]`` or the CUDA layout ``[B, T, H, D]``.
        is_rocm: Force the layout interpretation. When ``None`` the layout is
            inferred from the shapes; if the shapes fit both layouts (the
            last two key dims are equal and both caches have the same shape)
            the torch build (``torch.version.hip``) breaks the tie.

    Returns:
        The caches unchanged for the ROCm layout, or permuted views for the
        CUDA layout. No data is copied.

    Raises:
        ValueError: Either cache is not 4D.
    """
    if key_cache.ndim != 4 or value_cache.ndim != 4:
        raise ValueError("key_cache/value_cache must be 4D tensors.")

    if is_rocm is None:
        # ROCm path (FlashAttention v1): K=[B,H,T,D] and V=[B,H,D,T]
        is_rocm = (value_cache.shape[3] == key_cache.shape[2]
                   and value_cache.shape[2] == key_cache.shape[3])
        if is_rocm and key_cache.shape == value_cache.shape:
            # A CUDA-layout cache with num_heads == head_dim has exactly the
            # same shape signature, so shapes alone cannot decide.
            is_rocm = getattr(torch.version, "hip", None) is not None

    if is_rocm:
        return key_cache, value_cache

    # CUDA path: K=[B,T,H,D] and V=[B,T,H,D]
    k_view = key_cache.permute(0, 2, 1, 3)
    v_view = value_cache.permute(0, 2, 3, 1)
    return k_view, v_view
...@@ -34,7 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -34,7 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_topk_budget_step, from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range) count_prompt_must_keep_in_range)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
...@@ -1185,6 +1186,36 @@ class Scheduler(SchedulerInterface): ...@@ -1185,6 +1186,36 @@ class Scheduler(SchedulerInterface):
end_pos = request.num_computed_tokens end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens prompt_end = request.num_prompt_tokens
# Chunked prefill: do not change the prompt KV length mid-prefill.
# Otherwise, the next prefill chunk would attend to a truncated
# history (semantic change / quality collapse). Instead, keep the
# full prompt KV until the prompt is fully ingested, then apply a
# one-shot prompt compaction before decode.
if self.scheduler_config.chunked_prefill_enabled:
if start_pos >= prompt_end:
# Decode token(s): keep all.
request.num_kv_tokens += num_scheduled_token
continue
if end_pos < prompt_end:
# Prompt is still being ingested: keep all tokens for now.
request.num_kv_tokens += num_scheduled_token
continue
# This step finishes the prompt (and may include decode tokens
# in rare cases). Apply the final prompt compression length.
kept_prompt_total = compute_prompt_keep_len(
prompt_len=prompt_end,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
kept_decode = max(0, end_pos - max(start_pos, prompt_end))
request.num_kv_tokens = kept_prompt_total + kept_decode
continue
# Decode token(s): keep all. # Decode token(s): keep all.
decode_start = max(start_pos, prompt_end) decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start) kept_decode = max(0, end_pos - decode_start)
......
...@@ -186,3 +186,63 @@ def compute_topk_budget_step( ...@@ -186,3 +186,63 @@ def compute_topk_budget_step(
step_keep = bud_upto_end - bud_upto_start step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total) return _clamp_int(step_keep, 0, step_total)
def compute_prompt_topk_keep_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the total number of *candidate* prompt tokens to retain.

    Candidates are the prompt tokens outside the always-kept region
    (protected prefix/suffix and, optionally, the last prompt token). The
    candidate count comes from ``_candidate_total`` and the retained share
    from ``_candidate_keep_total``; with no candidates the result is 0.
    """
    num_candidates = _candidate_total(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    if num_candidates > 0:
        return _candidate_keep_total(
            candidate_total=num_candidates,
            prompt_ratio=prompt_ratio,
            prompt_budget=prompt_budget,
        )
    return 0
def compute_prompt_keep_len(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return how many prompt tokens survive compression.

    The result is the must-keep count over the whole prompt plus the Top-K
    candidate count, clamped to ``[0, prompt_len]``. A non-positive
    ``prompt_len`` yields 0.
    """
    if prompt_len <= 0:
        return 0

    # Arguments shared by both helper calls.
    protection = dict(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    always_kept = count_prompt_must_keep_in_range(
        start_pos=0,
        end_pos=prompt_len,
        **protection,
    )
    scored_kept = compute_prompt_topk_keep_total(
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
        **protection,
    )
    return _clamp_int(always_kept + scored_kept, 0, prompt_len)
...@@ -52,6 +52,12 @@ class CachedRequestState: ...@@ -52,6 +52,12 @@ class CachedRequestState:
repr=False, repr=False,
compare=False) compare=False)
# Chunked prefill (scheme 3): cached prompt compaction plan.
# Computed on the last prompt chunk; applied before the first decode step.
kv_compression_prompt_idx_sorted: Optional[torch.Tensor] = None # [K] int32
kv_compression_prompt_keep_len: Optional[int] = None
kv_compression_prompt_prompt_len: Optional[int] = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = len(self.prompt_token_ids)
......
...@@ -55,7 +55,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget ...@@ -55,7 +55,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec, KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec) SlidingWindowSpec)
from vllm.v1.kv_compression.budget import compute_topk_budget_step from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
compute_topk_budget_step)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
...@@ -368,6 +369,49 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -368,6 +369,49 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device=self.device, device=self.device,
) )
# Chunked-prefill prompt-end KV compression metadata (scheme 3).
# Per-request: whether this step finishes the prompt and should compute
# global prompt indices (score/topk) for a one-shot compaction.
self.kv_compression_prompt_end_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_end_np = self.kv_compression_prompt_end_cpu.numpy()
self.kv_compression_prompt_end = torch.zeros(
self.max_num_reqs,
dtype=torch.bool,
device=self.device,
)
# Per-request: prompt length (tokens) and Top-K keep count among prompt
# candidates (excluding protected prefix/suffix).
self.kv_compression_prompt_lens_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_lens_np = self.kv_compression_prompt_lens_cpu.numpy()
self.kv_compression_prompt_lens = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.kv_compression_prompt_topk_keep_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_prompt_topk_keep_np = self.kv_compression_prompt_topk_keep_cpu.numpy()
self.kv_compression_prompt_topk_keep = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.kv_compression_prompt_topk_keep_max: Optional[int] = None
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values # means this layer will perform attention using the keys and values
...@@ -728,43 +772,32 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -728,43 +772,32 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
must_keep_np = self.kv_compression_must_keep_np[ if self.scheduler_config.chunked_prefill_enabled:
:total_num_scheduled_tokens] # Scheme 3: with chunked prefill, defer compaction until after
must_keep_np.fill(False) # the full prompt is ingested. Otherwise, the next prefill chunk
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs] # would attend to a truncated history and quality can collapse.
topk_budget_np.fill(0) prompt_end_np = self.kv_compression_prompt_end_np[:num_reqs]
prompt_end_np.fill(False)
for req_idx in range(num_reqs): prompt_lens_np = self.kv_compression_prompt_lens_np[:num_reqs]
qlen = int(num_scheduled_tokens[req_idx]) prompt_lens_np.fill(0)
if qlen <= 0: topk_keep_np = self.kv_compression_prompt_topk_keep_np[:num_reqs]
continue topk_keep_np.fill(0)
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx]) for req_idx in range(num_reqs):
assert end - start == qlen qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
base_pos = int( continue
self.input_batch.num_computed_tokens_cpu[req_idx]) base_pos = int(self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx]) prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64) ends_prompt = (base_pos < prompt_len) and (end_pos >= prompt_len)
if not ends_prompt:
prompt_mask = pos < prompt_len continue
# Decode tokens are always kept.
must_keep = ~prompt_mask prompt_end_np[req_idx] = True
prompt_lens_np[req_idx] = prompt_len
if np.any(prompt_mask): topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len, prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix, protected_prefix=protected_prefix,
protected_suffix=protected_suffix, protected_suffix=protected_suffix,
keep_last_token=keep_last, keep_last_token=keep_last,
...@@ -772,13 +805,62 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -772,13 +805,62 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
prompt_budget=prompt_budget, prompt_budget=prompt_budget,
) )
must_keep_np[start:end] = must_keep self.kv_compression_prompt_topk_keep_max = int(
topk_keep_np[:num_reqs].max()) if num_reqs > 0 else 0
self.kv_compression_needs_compaction = False
else:
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(
cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally # Decode-only fast path: if all scheduled tokens are
# kept and there is no Top-K budget, KV compaction is a no-op and we # unconditionally kept and there is no Top-K budget, KV
# can skip score/topk/dst entirely in the attention backend. # compaction is a no-op and we can skip score/topk/dst entirely
self.kv_compression_needs_compaction = (not must_keep_np.all()) or ( # in the attention backend.
topk_budget_np > 0).any() self.kv_compression_needs_compaction = (not must_keep_np.all(
)) or (topk_budget_np > 0).any()
else: else:
self.kv_compression_needs_compaction = False self.kv_compression_needs_compaction = False
...@@ -860,14 +942,28 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -860,14 +942,28 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True) non_blocking=True)
if use_kv_compression: if use_kv_compression:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_( if self.scheduler_config.chunked_prefill_enabled:
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens], self.kv_compression_prompt_end[:num_reqs].copy_(
non_blocking=True, self.kv_compression_prompt_end_cpu[:num_reqs],
) non_blocking=True,
self.kv_compression_topk_budget[:num_reqs].copy_( )
self.kv_compression_topk_budget_cpu[:num_reqs], self.kv_compression_prompt_lens[:num_reqs].copy_(
non_blocking=True, self.kv_compression_prompt_lens_cpu[:num_reqs],
) non_blocking=True,
)
self.kv_compression_prompt_topk_keep[:num_reqs].copy_(
self.kv_compression_prompt_topk_keep_cpu[:num_reqs],
non_blocking=True,
)
elif self.kv_compression_needs_compaction:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache # Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0) self.seq_lens[num_reqs:].fill_(0)
...@@ -3350,6 +3446,155 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3350,6 +3446,155 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
return (attn_metadata, attention_cuda_graphs, logits_indices, return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata, num_scheduled_tokens) spec_decode_metadata, num_scheduled_tokens)
def _stash_kv_compression_prompt_payload(self) -> None:
    """Copy the prompt-end compaction plan onto the cached request states.

    Does nothing unless KV compression and chunked prefill are both
    enabled and the current forward context carries a
    ``_kv_compression_prompt_payload`` mapping in which all of
    ``req_indices``, ``idx_sorted``, ``keep_len`` and ``prompt_lens`` are
    present. For every payload row that maps to a live request, the row's
    index tensor, keep length and prompt length are stored on that
    request's state so they can be consumed by a later step.
    """
    compression_on = envs.VLLM_ENABLE_KV_COMPRESSION
    if not (compression_on
            and self.scheduler_config.chunked_prefill_enabled):
        return

    # The payload is an optional attribute of the forward context; it is
    # presumably attached by the attention backend on prompt-end steps
    # (the producer is not visible here -- confirm).
    payload = getattr(get_forward_context(),
                      "_kv_compression_prompt_payload", None)
    if payload is None:
        return

    batch_rows = payload.get("req_indices")
    plan_rows = payload.get("idx_sorted")
    keep_counts = payload.get("keep_len")
    prompt_sizes = payload.get("prompt_lens")
    required = (batch_rows, plan_rows, keep_counts, prompt_sizes)
    if any(field is None for field in required):
        # Incomplete payload: nothing can be stashed safely.
        return

    def _as_host_ints(values):
        # Move to CPU as int64 and convert to a plain Python list.
        return values.to(device="cpu", dtype=torch.int64).tolist()

    batch_rows_host = _as_host_ints(batch_rows)
    keep_host = _as_host_ints(keep_counts)
    prompt_host = _as_host_ints(prompt_sizes)

    for row, batch_idx in enumerate(batch_rows_host):
        # Ignore rows that point outside the current input batch.
        if not 0 <= batch_idx < len(self.input_batch.req_ids):
            continue
        req_id = self.input_batch.req_ids[batch_idx]
        state = None if req_id is None else self.requests.get(req_id)
        if state is None:
            continue
        # NOTE(review): this stores `idx_sorted[row]` as-is (no copy). If
        # the payload tensor is a buffer reused by later forward passes,
        # the cached plan could be overwritten -- confirm with the producer.
        state.kv_compression_prompt_idx_sorted = plan_rows[row]
        state.kv_compression_prompt_keep_len = int(keep_host[row])
        state.kv_compression_prompt_prompt_len = int(prompt_host[row])
def _maybe_apply_kv_compression_prompt_compaction(self) -> None:
    """Apply the cached one-shot prompt KV compaction to pending requests.

    Does nothing unless KV compression and chunked prefill are both
    enabled. A request in the current input batch is pending when it has a
    cached ``kv_compression_prompt_idx_sorted`` plan and its
    ``num_computed_tokens`` has reached ``num_prompt_tokens``.

    For every KV cache group, a padded block table and a padded
    ``[B, K_max]`` index batch are built for the pending requests and
    handed to ``front_compact_inplace_fa_triton`` once per distinct KV
    cache in that group. The cached plan of every compacted request is
    cleared afterwards so the compaction is applied only once.
    """
    if (not envs.VLLM_ENABLE_KV_COMPRESSION
            or not self.scheduler_config.chunked_prefill_enabled):
        return

    # Collect requests that have a stashed plan and a fully ingested prompt.
    pending_req_ids: list[str] = []
    for req_id in self.input_batch.req_ids:
        if req_id is None:
            continue
        rs = self.requests.get(req_id)
        if rs is None:
            continue
        if rs.kv_compression_prompt_idx_sorted is None:
            continue
        # Only apply once the prompt is fully ingested (decode stage).
        if rs.num_computed_tokens < rs.num_prompt_tokens:
            continue
        pending_req_ids.append(req_id)

    if not pending_req_ids:
        return

    device = self.device
    # (request id, sorted keep indices, number of kept prompt tokens)
    pending_states: list[tuple[str, torch.Tensor, int]] = []
    for req_id in pending_req_ids:
        rs = self.requests[req_id]
        keep = rs.kv_compression_prompt_keep_len
        idx = rs.kv_compression_prompt_idx_sorted
        if keep is None or idx is None:
            continue
        keep_i = int(keep)
        if keep_i <= 0:
            # No prompt tokens kept; clear and skip.
            rs.kv_compression_prompt_idx_sorted = None
            rs.kv_compression_prompt_keep_len = None
            rs.kv_compression_prompt_prompt_len = None
            continue
        pending_states.append((req_id, idx, keep_i))

    if not pending_states:
        return

    # Pad every request's index row to the largest keep length; the
    # per-request valid length travels separately in `keep_tensor`.
    B = len(pending_states)
    keep_list = [k for _, _, k in pending_states]
    K_max = max(keep_list)
    idx_batch = torch.zeros((B, K_max), device=device, dtype=torch.int32)
    for i, (_, row, k) in enumerate(pending_states):
        idx_batch[i, :k] = row[:k].to(device=device, dtype=torch.int32)
    keep_tensor = torch.tensor(keep_list,
                               device=device,
                               dtype=torch.int32)

    # Imported lazily so the Triton module is only loaded when a
    # compaction actually has to run.
    from vllm.v1.attention.kv_compression.kv_cache_triton import (
        front_compact_inplace_fa_triton, make_fa_cache_view)

    # Apply compaction to every attention layer's KV cache in-place.
    for group_id, kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups):
        max_blocks = 0
        for req_id, _, _ in pending_states:
            rs = self.requests[req_id]
            if group_id >= len(rs.block_ids):
                continue
            max_blocks = max(max_blocks, len(rs.block_ids[group_id]))
        if max_blocks == 0:
            continue

        # Zero-padded block table for the pending requests in this group.
        block_table_cpu = torch.zeros((B, max_blocks),
                                      dtype=torch.int32,
                                      device="cpu")
        for i, (req_id, _, _) in enumerate(pending_states):
            rs = self.requests[req_id]
            if group_id >= len(rs.block_ids):
                continue
            ids = rs.block_ids[group_id]
            if ids:
                block_table_cpu[i, :len(ids)] = torch.tensor(
                    ids, dtype=torch.int32, device="cpu")
        block_table = block_table_cpu.to(device=device, non_blocking=True)

        # The in-place gather is not idempotent: a second pass over the
        # same memory with the same block table would read slots that the
        # first pass already overwrote. Layers that alias one physical
        # cache (cross-layer KV sharing, or layer names resolving to the
        # same index) must therefore be compacted only once per group.
        # The set is per group because each group has its own block table.
        compacted: set[tuple[int, int]] = set()
        for layer_name in kv_cache_group_spec.layer_names:
            layer_index = self._extract_layer_index(layer_name)
            if layer_index >= len(self.kv_caches):
                continue
            kv_cache = self.kv_caches[layer_index]
            if not current_platform.is_rocm():
                if not isinstance(kv_cache, torch.Tensor):
                    continue
                key_cache, value_cache = kv_cache.unbind(0)
            else:
                if not isinstance(kv_cache,
                                  (tuple, list)) or len(kv_cache) != 2:
                    continue
                key_cache, value_cache = kv_cache

            cache_key = (key_cache.data_ptr(), value_cache.data_ptr())
            if cache_key in compacted:
                continue
            compacted.add(cache_key)

            k_view, v_view = make_fa_cache_view(key_cache=key_cache,
                                                value_cache=value_cache)
            front_compact_inplace_fa_triton(
                k_view,
                v_view,
                block_table,
                idx_batch,
                keep_tensor,
            )

    # Clear pending state after successful compaction.
    for req_id, _, _ in pending_states:
        rs = self.requests.get(req_id)
        if rs is None:
            continue
        rs.kv_compression_prompt_idx_sorted = None
        rs.kv_compression_prompt_keep_len = None
        rs.kv_compression_prompt_prompt_len = None
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -3446,6 +3691,10 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3446,6 +3691,10 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
self._maybe_apply_kv_compression_prompt_compaction()
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS: if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
...@@ -3473,6 +3722,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3473,6 +3722,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
...@@ -3498,6 +3748,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3498,6 +3748,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
self._stash_kv_compression_prompt_payload()
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment