Commit dbcb0376 authored by laibao's avatar laibao
Browse files

feat(kvpress): Runner 接入 KV 位置与注意力元数据

parent 3d4f8753
...@@ -304,6 +304,8 @@ class CommonAttentionMetadata: ...@@ -304,6 +304,8 @@ class CommonAttentionMetadata:
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int num_actual_tokens: int
"""Total number of tokens in batch""" """Total number of tokens in batch"""
num_unpadded_tokens: int | None = None
"""Number of scheduled tokens excluding padding, if known."""
max_query_len: int max_query_len: int
"""Longest query in batch""" """Longest query in batch"""
max_seq_len: int max_seq_len: int
...@@ -332,6 +334,16 @@ class CommonAttentionMetadata: ...@@ -332,6 +334,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache: torch.Tensor | None = None _num_computed_tokens_cache: torch.Tensor | None = None
# KV compression metadata (experimental, v1 paged attention only).
kv_compression_must_keep: torch.Tensor | None = None
kv_compression_topk_budget: torch.Tensor | None = None
kv_compression_topk_budget_max: int | None = None
kv_compression_prompt_end: torch.Tensor | None = None
kv_compression_prompt_lens: torch.Tensor | None = None
kv_compression_prompt_topk_keep: torch.Tensor | None = None
kv_compression_prompt_topk_keep_max: int | None = None
def batch_size(self) -> int: def batch_size(self) -> int:
return self.seq_lens.shape[0] return self.seq_lens.shape[0]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import torch
_PROMPT_PAYLOAD_ATTR = "_kv_compression_prompt_payload"
_COMPACT_SLOTS_ATTR = "_kv_compression_compact_slots"
_COMPACT_SLOTS_BY_LAYER_ATTR = "_kv_compression_compact_slots_by_layer"
def get_kv_compression_prompt_payload(
forward_context: Any,
) -> Optional[dict[str, torch.Tensor]]:
return getattr(forward_context, _PROMPT_PAYLOAD_ATTR, None)
def set_kv_compression_prompt_payload(
forward_context: Any,
payload: dict[str, torch.Tensor],
) -> None:
setattr(forward_context, _PROMPT_PAYLOAD_ATTR, payload)
def _kv_compression_layer_key(layer: Any) -> str:
layer_name = getattr(layer, "layer_name", None)
if layer_name is None:
layer_name = str(id(layer))
return str(layer_name)
def get_kv_compression_compact_slots(
forward_context: Any,
*,
per_layer_topk: bool,
layer: Any,
) -> Optional[torch.Tensor]:
if per_layer_topk:
dst_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
None)
if dst_by_layer is None:
return None
return dst_by_layer.get(_kv_compression_layer_key(layer))
return getattr(forward_context, _COMPACT_SLOTS_ATTR, None)
def set_kv_compression_compact_slots(
forward_context: Any,
*,
per_layer_topk: bool,
layer: Any,
dst: torch.Tensor,
) -> None:
if per_layer_topk:
dst_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
None)
if dst_by_layer is None:
dst_by_layer = {}
setattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR, dst_by_layer)
dst_by_layer[_kv_compression_layer_key(layer)] = dst
else:
setattr(forward_context, _COMPACT_SLOTS_ATTR, dst)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import torch
import vllm.envs as envs
@dataclass
class KVCompressionAttentionMetadata:
"""Per-batch KV compression metadata consumed by attention backends."""
must_keep: Optional[torch.Tensor] = None
topk_budget: Optional[torch.Tensor] = None
topk_budget_max: Optional[int] = None
prompt_end: Optional[torch.Tensor] = None
prompt_lens: Optional[torch.Tensor] = None
prompt_topk_keep: Optional[torch.Tensor] = None
prompt_topk_keep_max: Optional[int] = None
def build_kv_compression_attn_metadata(
*,
runner: Any,
num_reqs: int,
num_actual_tokens: int,
) -> KVCompressionAttentionMetadata:
"""Build KV compression metadata for one attention step.
This helper keeps backend code thin and centralizes the logic for selecting
between per-step compaction (scheme 1/2) and prompt-end one-shot scoring
(scheme 3).
"""
meta = KVCompressionAttentionMetadata()
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return meta
# Scheme 1/2: compute compaction destinations every step.
if getattr(runner, "kv_compression_needs_compaction", False):
meta.must_keep = runner.kv_compression_must_keep[:num_actual_tokens]
meta.topk_budget = runner.kv_compression_topk_budget[:num_reqs]
# Avoid device->host sync by reading from the CPU staging buffer.
if num_reqs > 0:
meta.topk_budget_max = int(
runner.kv_compression_topk_budget_np[:num_reqs].max())
else:
meta.topk_budget_max = 0
return meta
# Scheme 3: compute global prompt indices only on the last prefill chunk,
# and perform the actual cache compaction before the first decode step.
scheduler_config = getattr(runner, "scheduler_config", None)
if scheduler_config is None or not getattr(scheduler_config,
"enable_chunked_prefill",
False):
return meta
if num_reqs <= 0:
return meta
if not runner.kv_compression_prompt_end_np[:num_reqs].any():
return meta
meta.prompt_end = runner.kv_compression_prompt_end[:num_reqs]
meta.prompt_lens = runner.kv_compression_prompt_lens[:num_reqs]
meta.prompt_topk_keep = runner.kv_compression_prompt_topk_keep[:num_reqs]
meta.prompt_topk_keep_max = int(
getattr(runner, "kv_compression_prompt_topk_keep_max", 0) or 0)
return meta
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import torch
def init_kv_compression_runner_buffers(
*,
runner: Any,
max_num_tokens: int,
max_num_reqs: int,
device: torch.device,
pin_memory: bool,
) -> None:
"""Initialize per-runner buffers used by KV compression.
This helper keeps `gpu_model_runner.py` focused on orchestration while
preserving the existing attribute-based access patterns.
"""
# KV positions are decoupled from logical positions when KV compression is
# enabled. Keep a separate buffer to avoid recomputing or overwriting the
# logical `positions_np` (used for RoPE / token lookup).
runner.kv_positions_cpu = torch.zeros(
max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_positions_np = runner.kv_positions_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
runner.kv_compression_must_keep_cpu = torch.zeros(
max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_must_keep_np = runner.kv_compression_must_keep_cpu.numpy()
runner.kv_compression_must_keep = torch.zeros(
max_num_tokens,
dtype=torch.bool,
device=device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
runner.kv_compression_topk_budget_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_topk_budget_np = runner.kv_compression_topk_budget_cpu.numpy()
runner.kv_compression_topk_budget = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
# Chunked-prefill prompt-end KV compression metadata (scheme 3).
# Per-request: whether this step finishes the prompt and should compute
# global prompt indices (score/topk) for a one-shot compaction.
runner.kv_compression_prompt_end_cpu = torch.zeros(
max_num_reqs,
dtype=torch.bool,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_end_np = runner.kv_compression_prompt_end_cpu.numpy()
runner.kv_compression_prompt_end = torch.zeros(
max_num_reqs,
dtype=torch.bool,
device=device,
)
# Per-request: prompt length (tokens) and Top-K keep count among prompt
# candidates (excluding protected prefix/suffix).
runner.kv_compression_prompt_lens_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_lens_np = runner.kv_compression_prompt_lens_cpu.numpy()
runner.kv_compression_prompt_lens = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
runner.kv_compression_prompt_topk_keep_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_topk_keep_np = runner.kv_compression_prompt_topk_keep_cpu.numpy()
runner.kv_compression_prompt_topk_keep = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
runner.kv_compression_prompt_topk_keep_max = None # type: Optional[int]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import numpy as np
import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
compute_topk_budget_step)
def prepare_kv_compression_for_step(
*,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray, # [B] int32
cu_num_tokens: np.ndarray, # [B] int64/int32 cumulative scheduled tokens
req_indices: np.ndarray, # [T] int64, request index per token
arange: np.ndarray, # [T] int64, position-within-request per token
num_computed_tokens_cpu: np.ndarray, # [max_reqs] int32/int64
num_prompt_tokens: np.ndarray, # [max_reqs] int32/int64
num_kv_tokens_cpu: np.ndarray, # [max_reqs] int32/int64
kv_positions_np: np.ndarray, # [T] int64 (out)
must_keep_np: np.ndarray, # [T] bool (out; scheme 1/2 only)
topk_budget_np: np.ndarray, # [B] int32 (out; scheme 1/2 only)
prompt_end_np: np.ndarray, # [B] bool (out; scheme 3 only)
prompt_lens_np: np.ndarray, # [B] int32 (out; scheme 3 only)
prompt_topk_keep_np: np.ndarray, # [B] int32 (out; scheme 3 only)
chunked_prefill_enabled: bool,
) -> tuple[bool, Optional[int]]:
"""Prepare KV compression metadata for a single model step (CPU-side).
Fills:
- `kv_positions_np`: per-token KV write positions (decoupled from logical
RoPE positions).
- Scheme 3 (chunked prefill): `prompt_end/prompt_lens/prompt_topk_keep`.
- Scheme 1/2 (non-chunked): `must_keep/topk_budget`.
Returns:
(needs_compaction, prompt_topk_keep_max)
"""
if total_num_scheduled_tokens <= 0 or num_reqs <= 0:
return False, None
# KV positions (where scheduled tokens are written before optional
# compaction).
np.add(num_kv_tokens_cpu[req_indices], arange, out=kv_positions_np)
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
if chunked_prefill_enabled:
# Scheme 3: with chunked prefill, defer compaction until after the full
# prompt is ingested. Otherwise, the next prefill chunk would attend to
# a truncated history and quality can collapse.
prompt_end_np.fill(False)
prompt_lens_np.fill(0)
prompt_topk_keep_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
base_pos = int(num_computed_tokens_cpu[req_idx])
prompt_len = int(num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
ends_prompt = (base_pos < prompt_len) and (end_pos >= prompt_len)
if not ends_prompt:
continue
prompt_end_np[req_idx] = True
prompt_lens_np[req_idx] = prompt_len
prompt_topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
prompt_topk_keep_max = int(prompt_topk_keep_np[:num_reqs].max())
return False, prompt_topk_keep_max
# Scheme 1/2: per-step compaction within the scheduled segment.
must_keep_np.fill(False)
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(num_computed_tokens_cpu[req_idx])
prompt_len = int(num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos_in_req = arange[start:end].astype(np.int64, copy=False)
pos = base_pos + pos_in_req
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally kept
# and there is no Top-K budget, KV compaction is a no-op and can be skipped.
needs_compaction = (not must_keep_np.all()) or (topk_budget_np > 0).any()
return bool(needs_compaction), None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import (
AttentionSpec,
CrossAttentionSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.kv_compression.forward_context import get_kv_compression_prompt_payload
def stash_kv_compression_prompt_payload_to_requests(*, runner: Any) -> None:
"""Persist prompt-end compaction indices from the forward context.
This is the runner-side half of chunked-prefill scheme 3:
flash_attn -> forward_context payload -> request state stash ->
(next step) one-shot KV compaction.
"""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
scheduler_config = getattr(runner, "scheduler_config", None)
if scheduler_config is None or not getattr(scheduler_config,
"enable_chunked_prefill",
False):
return
forward_context = get_forward_context()
payload = get_kv_compression_prompt_payload(forward_context)
if payload is None:
return
req_indices = payload.get("req_indices")
idx_sorted = payload.get("idx_sorted")
keep_len = payload.get("keep_len")
prompt_lens = payload.get("prompt_lens")
if (req_indices is None or idx_sorted is None or keep_len is None
or prompt_lens is None):
return
input_batch = getattr(runner, "input_batch", None)
if input_batch is None:
return
req_ids = getattr(input_batch, "req_ids", None)
if req_ids is None:
return
requests = getattr(runner, "requests", None)
if requests is None:
return
req_indices_cpu = req_indices.to(device="cpu", dtype=torch.int64).tolist()
keep_cpu = keep_len.to(device="cpu", dtype=torch.int64).tolist()
prompt_cpu = prompt_lens.to(device="cpu", dtype=torch.int64).tolist()
for i, b in enumerate(req_indices_cpu):
if b < 0 or b >= len(req_ids):
continue
req_id = req_ids[b]
if req_id is None:
continue
rs = requests.get(req_id)
if rs is None:
continue
rs.kv_compression_prompt_idx_sorted = idx_sorted[i]
rs.kv_compression_prompt_keep_len = int(keep_cpu[i])
rs.kv_compression_prompt_prompt_len = int(prompt_cpu[i])
def maybe_apply_kv_compression_prompt_compaction(*, runner: Any) -> None:
"""Apply one-shot prompt KV compaction before the first decode step."""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
if not current_platform.is_cuda_alike():
return
scheduler_config = getattr(runner, "scheduler_config", None)
if scheduler_config is None or not getattr(scheduler_config,
"enable_chunked_prefill",
False):
return
input_batch = getattr(runner, "input_batch", None)
if input_batch is None:
return
requests = getattr(runner, "requests", None)
if requests is None:
return
pending_req_ids: list[str] = []
for req_id in input_batch.req_ids:
if req_id is None:
continue
rs = requests.get(req_id)
if rs is None:
continue
if rs.kv_compression_prompt_idx_sorted is None:
continue
# Only apply once the prompt is fully ingested (decode stage).
if rs.num_computed_tokens < rs.num_prompt_tokens:
continue
pending_req_ids.append(req_id)
if not pending_req_ids:
return
device = runner.device
pending_states: list[tuple[str, torch.Tensor, int]] = []
for req_id in pending_req_ids:
rs = requests[req_id]
keep = rs.kv_compression_prompt_keep_len
idx = rs.kv_compression_prompt_idx_sorted
if keep is None or idx is None:
continue
keep_i = int(keep)
if keep_i <= 0:
# No prompt tokens kept; clear and skip.
rs.kv_compression_prompt_idx_sorted = None
rs.kv_compression_prompt_keep_len = None
rs.kv_compression_prompt_prompt_len = None
continue
pending_states.append((req_id, idx, keep_i))
if not pending_states:
return
B = len(pending_states)
keep_list = [k for _, _, k in pending_states]
K_max = max(keep_list)
idx_batch = torch.zeros((B, K_max), device=device, dtype=torch.int32)
for i, (_, row, k) in enumerate(pending_states):
idx_batch[i, :k] = row[:k].to(device=device, dtype=torch.int32)
keep_tensor = torch.tensor(keep_list, device=device, dtype=torch.int32)
from vllm.v1.kv_compression.kv_cache_triton import (
front_compact_inplace_fa_triton, make_fa_cache_view)
kv_cache_config = getattr(runner, "kv_cache_config", None)
if kv_cache_config is None:
return
# Apply compaction to every attention layer's KV cache in-place.
for group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
max_blocks = 0
for req_id, _, _ in pending_states:
rs = requests[req_id]
if group_id >= len(rs.block_ids):
continue
max_blocks = max(max_blocks, len(rs.block_ids[group_id]))
if max_blocks == 0:
continue
block_table_cpu = torch.zeros((B, max_blocks),
dtype=torch.int32,
device="cpu")
for i, (req_id, _, _) in enumerate(pending_states):
rs = requests[req_id]
if group_id >= len(rs.block_ids):
continue
ids = rs.block_ids[group_id]
if ids:
block_table_cpu[i, :len(ids)] = torch.tensor(ids,
dtype=torch.int32,
device="cpu")
block_table = block_table_cpu.to(device=device, non_blocking=True)
static_forward_context = getattr(
getattr(runner, "compilation_config", None),
"static_forward_context",
None,
)
if static_forward_context is None:
continue
seen_cache_ptrs: set[int] = set()
for layer_name in kv_cache_group_spec.layer_names:
# Skip non-self-attention caches (e.g., encoder/decoder cross-attn)
# and non-attention cache specs (e.g., Mamba).
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
kv_cache_spec = kv_cache_spec.kv_cache_specs.get(layer_name)
if kv_cache_spec is None or not isinstance(kv_cache_spec, AttentionSpec):
continue
if isinstance(kv_cache_spec, CrossAttentionSpec):
continue
layer = static_forward_context.get(layer_name)
if layer is None:
continue
kv_cache_list = getattr(layer, "kv_cache", None)
if not isinstance(kv_cache_list, list) or not kv_cache_list:
continue
kv_cache = kv_cache_list[0]
if not current_platform.is_rocm():
if not isinstance(kv_cache, torch.Tensor):
continue
cache_ptr = int(kv_cache.data_ptr())
if cache_ptr in seen_cache_ptrs:
continue
seen_cache_ptrs.add(cache_ptr)
key_cache, value_cache = kv_cache.unbind(0)
else:
if (not isinstance(kv_cache, (tuple, list))
or len(kv_cache) != 2):
continue
key_cache, value_cache = kv_cache
cache_ptr = int(key_cache.data_ptr())
if cache_ptr in seen_cache_ptrs:
continue
seen_cache_ptrs.add(cache_ptr)
k_view, v_view = make_fa_cache_view(key_cache=key_cache,
value_cache=value_cache)
front_compact_inplace_fa_triton(
k_view,
v_view,
block_table,
idx_batch,
keep_tensor,
)
# Clear pending state after successful compaction.
for req_id, _, _ in pending_states:
rs = requests.get(req_id)
if rs is None:
continue
rs.kv_compression_prompt_idx_sorted = None
rs.kv_compression_prompt_keep_len = None
rs.kv_compression_prompt_prompt_len = None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import numpy as np
import vllm.envs as envs
from vllm.v1.kv_compression.runner_prepare import prepare_kv_compression_for_step
def maybe_prepare_kv_compression_for_runner_step(
*,
runner: Any,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray, # [B] int32
cu_num_tokens: np.ndarray, # [B] int64/int32
req_indices: np.ndarray, # [T] int64
arange: np.ndarray, # [T] int64
) -> Optional[np.ndarray]:
"""Prepare per-step KV compression metadata on CPU.
Returns the per-token KV positions (`kv_positions_np`) or None if KV
compression is disabled.
"""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
runner.kv_compression_needs_compaction = False
return None
kv_positions_np = runner.kv_positions_np[:total_num_scheduled_tokens]
must_keep_np = runner.kv_compression_must_keep_np[:total_num_scheduled_tokens]
topk_budget_np = runner.kv_compression_topk_budget_np[:num_reqs]
prompt_end_np = runner.kv_compression_prompt_end_np[:num_reqs]
prompt_lens_np = runner.kv_compression_prompt_lens_np[:num_reqs]
topk_keep_np = runner.kv_compression_prompt_topk_keep_np[:num_reqs]
needs_compaction, prompt_topk_keep_max = prepare_kv_compression_for_step(
num_reqs=num_reqs,
total_num_scheduled_tokens=total_num_scheduled_tokens,
num_scheduled_tokens=num_scheduled_tokens,
cu_num_tokens=cu_num_tokens,
req_indices=req_indices,
arange=arange,
num_computed_tokens_cpu=runner.input_batch.num_computed_tokens_cpu,
num_prompt_tokens=runner.input_batch.num_prompt_tokens,
num_kv_tokens_cpu=runner.input_batch.num_kv_tokens_cpu,
kv_positions_np=kv_positions_np,
must_keep_np=must_keep_np,
topk_budget_np=topk_budget_np,
prompt_end_np=prompt_end_np,
prompt_lens_np=prompt_lens_np,
prompt_topk_keep_np=topk_keep_np,
chunked_prefill_enabled=runner.scheduler_config.enable_chunked_prefill,
)
runner.kv_compression_needs_compaction = bool(needs_compaction)
if prompt_topk_keep_max is not None:
runner.kv_compression_prompt_topk_keep_max = int(prompt_topk_keep_max)
return kv_positions_np
def maybe_copy_kv_compression_step_tensors_to_gpu(
*,
runner: Any,
num_reqs: int,
total_num_scheduled_tokens: int,
non_blocking: bool = True,
) -> None:
"""Stage per-step KV compression tensors to GPU if needed."""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
if runner.scheduler_config.enable_chunked_prefill:
runner.kv_compression_prompt_end[:num_reqs].copy_(
runner.kv_compression_prompt_end_cpu[:num_reqs],
non_blocking=non_blocking,
)
runner.kv_compression_prompt_lens[:num_reqs].copy_(
runner.kv_compression_prompt_lens_cpu[:num_reqs],
non_blocking=non_blocking,
)
runner.kv_compression_prompt_topk_keep[:num_reqs].copy_(
runner.kv_compression_prompt_topk_keep_cpu[:num_reqs],
non_blocking=non_blocking,
)
return
if runner.kv_compression_needs_compaction:
runner.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
runner.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=non_blocking,
)
runner.kv_compression_topk_budget[:num_reqs].copy_(
runner.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=non_blocking,
)
...@@ -97,6 +97,7 @@ from vllm.utils.torch_utils import ( ...@@ -97,6 +97,7 @@ from vllm.utils.torch_utils import (
get_dtype_size, get_dtype_size,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.platforms import current_platform
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -114,6 +115,16 @@ from vllm.v1.attention.backends.utils import ( ...@@ -114,6 +115,16 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_compression.runner_buffers import init_kv_compression_runner_buffers
from vllm.v1.kv_compression.metadata import build_kv_compression_attn_metadata
from vllm.v1.kv_compression.runner_prompt_compaction import (
maybe_apply_kv_compression_prompt_compaction,
stash_kv_compression_prompt_payload_to_requests,
)
from vllm.v1.kv_compression.runner_step import (
maybe_copy_kv_compression_step_tensors_to_gpu,
maybe_prepare_kv_compression_for_runner_step,
)
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
...@@ -562,6 +573,16 @@ class GPUModelRunner( ...@@ -562,6 +573,16 @@ class GPUModelRunner(
# Persistent buffers for CUDA graphs. # Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_cuda_alike():
init_kv_compression_runner_buffers(
runner=self,
max_num_tokens=self.max_num_tokens,
max_num_reqs=self.max_num_reqs,
device=self.device,
pin_memory=self.pin_memory,
)
else:
self.kv_compression_needs_compaction = False
self.query_start_loc = self._make_buffer( self.query_start_loc = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32 self.max_num_reqs + 1, dtype=torch.int32
) )
...@@ -953,6 +974,7 @@ class GPUModelRunner( ...@@ -953,6 +974,7 @@ class GPUModelRunner(
generator=generator, generator=generator,
block_ids=new_req_data.block_ids, block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[], output_token_ids=[],
lora_request=new_req_data.lora_request, lora_request=new_req_data.lora_request,
) )
...@@ -987,6 +1009,7 @@ class GPUModelRunner( ...@@ -987,6 +1009,7 @@ class GPUModelRunner(
for i, req_id in enumerate(req_data.req_ids): for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i] num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i] new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i] num_output_tokens = req_data.num_output_tokens[i]
...@@ -1014,10 +1037,12 @@ class GPUModelRunner( ...@@ -1014,10 +1037,12 @@ class GPUModelRunner(
num_accepted = valid_sampled_token_count[prev_req_index] - 1 num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected num_computed_tokens -= num_rejected
num_kv_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted) req_state.output_token_ids.extend([-1] * num_accepted)
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
if not is_last_rank: if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
...@@ -1074,6 +1099,7 @@ class GPUModelRunner( ...@@ -1074,6 +1099,7 @@ class GPUModelRunner(
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if new_block_ids is not None: if new_block_ids is not None:
self.input_batch.block_table.append_row(new_block_ids, req_index) self.input_batch.block_table.append_row(new_block_ids, req_index)
...@@ -1183,6 +1209,7 @@ class GPUModelRunner( ...@@ -1183,6 +1209,7 @@ class GPUModelRunner(
req_state.pooling_params = new_req_data.pooling_params req_state.pooling_params = new_req_data.pooling_params
req_state.block_ids = new_req_data.block_ids req_state.block_ids = new_req_data.block_ids
req_state.num_computed_tokens = new_req_data.num_computed_tokens req_state.num_computed_tokens = new_req_data.num_computed_tokens
req_state.num_kv_tokens = new_req_data.num_kv_tokens
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds req_state.prompt_token_ids, req_state.prompt_embeds
) )
...@@ -1482,6 +1509,16 @@ class GPUModelRunner( ...@@ -1482,6 +1509,16 @@ class GPUModelRunner(
out=positions_np, out=positions_np,
) )
kv_positions_np = maybe_prepare_kv_compression_for_runner_step(
runner=self,
num_reqs=num_reqs,
total_num_scheduled_tokens=total_num_scheduled_tokens,
num_scheduled_tokens=num_scheduled_tokens,
cu_num_tokens=cu_num_tokens,
req_indices=req_indices,
arange=arange,
)
# Calculate M-RoPE positions. # Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
...@@ -1557,8 +1594,19 @@ class GPUModelRunner( ...@@ -1557,8 +1594,19 @@ class GPUModelRunner(
output_idx += num_sched output_idx += num_sched
self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) positions_for_slot_mapping = (
kv_positions_np if kv_positions_np is not None else positions_np
)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_for_slot_mapping
)
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
maybe_copy_kv_compression_step_tensors_to_gpu(
runner=self,
num_reqs=num_reqs,
total_num_scheduled_tokens=total_num_scheduled_tokens,
non_blocking=True,
)
# Prepare the attention metadata. # Prepare the attention metadata.
self.query_start_loc.np[0] = 0 self.query_start_loc.np[0] = 0
...@@ -1569,9 +1617,15 @@ class GPUModelRunner( ...@@ -1569,9 +1617,15 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
self.seq_lens.np[:num_reqs] = ( logical_seq_lens_np = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
) )
if envs.VLLM_ENABLE_KV_COMPRESSION:
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_kv_tokens_cpu[:num_reqs] + num_scheduled_tokens
)
else:
self.seq_lens.np[:num_reqs] = logical_seq_lens_np
# Fill unused with 0 for full cuda graph mode. # Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu() self.seq_lens.copy_to_gpu()
...@@ -1582,7 +1636,7 @@ class GPUModelRunner( ...@@ -1582,7 +1636,7 @@ class GPUModelRunner(
# Record which requests should not be sampled, # Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning # so that we could clear the sampled tokens before returning
self.discard_request_mask.np[:num_reqs] = ( self.discard_request_mask.np[:num_reqs] = (
self.seq_lens.np[:num_reqs] < num_tokens_np logical_seq_lens_np < num_tokens_np
) )
self.discard_request_mask.copy_to_gpu(num_reqs) self.discard_request_mask.copy_to_gpu(num_reqs)
...@@ -1749,12 +1803,25 @@ class GPUModelRunner( ...@@ -1749,12 +1803,25 @@ class GPUModelRunner(
], ],
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded, num_actual_tokens=num_tokens_padded,
num_unpadded_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
block_table_tensor=block_table_gid_0, block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0, slot_mapping=slot_mapping_gid_0,
causal=True, causal=True,
) )
kv_meta = build_kv_compression_attn_metadata(
runner=self,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
)
cm_base.kv_compression_must_keep = kv_meta.must_keep
cm_base.kv_compression_topk_budget = kv_meta.topk_budget
cm_base.kv_compression_topk_budget_max = kv_meta.topk_budget_max
cm_base.kv_compression_prompt_end = kv_meta.prompt_end
cm_base.kv_compression_prompt_lens = kv_meta.prompt_lens
cm_base.kv_compression_prompt_topk_keep = kv_meta.prompt_topk_keep
cm_base.kv_compression_prompt_topk_keep_max = kv_meta.prompt_topk_keep_max
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
...@@ -3510,6 +3577,10 @@ class GPUModelRunner( ...@@ -3510,6 +3577,10 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0 self.model_config.is_encoder_decoder and num_encoder_reqs > 0
) )
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
maybe_apply_kv_compression_prompt_compaction(runner=self)
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with ( with (
...@@ -3534,6 +3605,7 @@ class GPUModelRunner( ...@@ -3534,6 +3605,7 @@ class GPUModelRunner(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**model_kwargs, **model_kwargs,
) )
stash_kv_compression_prompt_payload_to_requests(runner=self)
with record_function_or_nullcontext("gpu_model_runner: postprocess"): with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
...@@ -5273,6 +5345,17 @@ class GPUModelRunner( ...@@ -5273,6 +5345,17 @@ class GPUModelRunner(
attention_backend_maps.append(attn_backends[0]) attention_backend_maps.append(attn_backends[0])
attention_backend_list.append(attn_backends[1]) attention_backend_list.append(attn_backends[1])
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_cuda_alike():
for attn_backend_set in attention_backend_list:
for attn_backend in attn_backend_set:
if attn_backend.get_name() != "FLASH_ATTN":
raise ValueError(
"KV compression currently requires the FLASH_ATTN "
"attention backend. "
f"Got {attn_backend.get_name()} "
f"({attn_backend.full_cls_name()})."
)
# Resolve cudagraph_mode before actually initialize metadata_builders # Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode( self._check_and_update_cudagraph_mode(
attention_backend_list, kv_cache_config.kv_cache_groups attention_backend_list, kv_cache_config.kv_cache_groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment