Commit 9d4330d2 authored by laibao's avatar laibao
Browse files

refactor: 抽离 runner 侧 KV compression 逻辑并统一 slot mapping

parent 3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import torch
def init_kv_compression_runner_buffers(
*,
runner: Any,
max_num_tokens: int,
max_num_reqs: int,
device: torch.device,
pin_memory: bool,
) -> None:
"""Initialize per-runner buffers used by KV compression.
This helper keeps `gpu_model_runner.py` focused on orchestration while
preserving the existing attribute-based access patterns.
"""
# KV positions are decoupled from logical positions when KV compression is
# enabled. Keep a separate buffer to avoid recomputing or overwriting the
# logical `positions_np` (used for RoPE / token lookup).
runner.kv_positions_cpu = torch.zeros(
max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_positions_np = runner.kv_positions_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
runner.kv_compression_must_keep_cpu = torch.zeros(
max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_must_keep_np = runner.kv_compression_must_keep_cpu.numpy()
runner.kv_compression_must_keep = torch.zeros(
max_num_tokens,
dtype=torch.bool,
device=device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
runner.kv_compression_topk_budget_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_topk_budget_np = runner.kv_compression_topk_budget_cpu.numpy()
runner.kv_compression_topk_budget = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
# Chunked-prefill prompt-end KV compression metadata (scheme 3).
# Per-request: whether this step finishes the prompt and should compute
# global prompt indices (score/topk) for a one-shot compaction.
runner.kv_compression_prompt_end_cpu = torch.zeros(
max_num_reqs,
dtype=torch.bool,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_end_np = runner.kv_compression_prompt_end_cpu.numpy()
runner.kv_compression_prompt_end = torch.zeros(
max_num_reqs,
dtype=torch.bool,
device=device,
)
# Per-request: prompt length (tokens) and Top-K keep count among prompt
# candidates (excluding protected prefix/suffix).
runner.kv_compression_prompt_lens_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_lens_np = runner.kv_compression_prompt_lens_cpu.numpy()
runner.kv_compression_prompt_lens = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
runner.kv_compression_prompt_topk_keep_cpu = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory,
)
runner.kv_compression_prompt_topk_keep_np = runner.kv_compression_prompt_topk_keep_cpu.numpy()
runner.kv_compression_prompt_topk_keep = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=device,
)
runner.kv_compression_prompt_topk_keep_max = None # type: Optional[int]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import numpy as np
import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_topk_keep_total,
compute_topk_budget_step)
def prepare_kv_compression_for_step(
*,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray, # [B] int32
cu_num_tokens: np.ndarray, # [B] int64/int32 cumulative scheduled tokens
req_indices: np.ndarray, # [T] int64, request index per token
arange: np.ndarray, # [T] int64, position-within-request per token
num_computed_tokens_cpu: np.ndarray, # [max_reqs] int32/int64
num_prompt_tokens: np.ndarray, # [max_reqs] int32/int64
num_kv_tokens_cpu: np.ndarray, # [max_reqs] int32/int64
kv_positions_np: np.ndarray, # [T] int64 (out)
must_keep_np: np.ndarray, # [T] bool (out; scheme 1/2 only)
topk_budget_np: np.ndarray, # [B] int32 (out; scheme 1/2 only)
prompt_end_np: np.ndarray, # [B] bool (out; scheme 3 only)
prompt_lens_np: np.ndarray, # [B] int32 (out; scheme 3 only)
prompt_topk_keep_np: np.ndarray, # [B] int32 (out; scheme 3 only)
chunked_prefill_enabled: bool,
) -> tuple[bool, Optional[int]]:
"""Prepare KV compression metadata for a single model step (CPU-side).
Fills:
- `kv_positions_np`: per-token KV write positions (decoupled from logical
RoPE positions).
- Scheme 3 (chunked prefill): `prompt_end/prompt_lens/prompt_topk_keep`.
- Scheme 1/2 (non-chunked): `must_keep/topk_budget`.
Returns:
(needs_compaction, prompt_topk_keep_max)
"""
if total_num_scheduled_tokens <= 0 or num_reqs <= 0:
return False, None
# KV positions (where scheduled tokens are written before optional
# compaction).
np.add(num_kv_tokens_cpu[req_indices], arange, out=kv_positions_np)
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
if chunked_prefill_enabled:
# Scheme 3: with chunked prefill, defer compaction until after the full
# prompt is ingested. Otherwise, the next prefill chunk would attend to
# a truncated history and quality can collapse.
prompt_end_np.fill(False)
prompt_lens_np.fill(0)
prompt_topk_keep_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
base_pos = int(num_computed_tokens_cpu[req_idx])
prompt_len = int(num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
ends_prompt = (base_pos < prompt_len) and (end_pos >= prompt_len)
if not ends_prompt:
continue
prompt_end_np[req_idx] = True
prompt_lens_np[req_idx] = prompt_len
prompt_topk_keep_np[req_idx] = compute_prompt_topk_keep_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
prompt_topk_keep_max = int(prompt_topk_keep_np[:num_reqs].max())
return False, prompt_topk_keep_max
# Scheme 1/2: per-step compaction within the scheduled segment.
must_keep_np.fill(False)
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(num_computed_tokens_cpu[req_idx])
prompt_len = int(num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos_in_req = arange[start:end].astype(np.int64, copy=False)
pos = base_pos + pos_in_req
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally kept
# and there is no Top-K budget, KV compaction is a no-op and can be skipped.
needs_compaction = (not must_keep_np.all()) or (topk_budget_np > 0).any()
return bool(needs_compaction), None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.v1.kv_compression.forward_context import get_kv_compression_prompt_payload
def stash_kv_compression_prompt_payload_to_requests(*, runner: Any) -> None:
"""Persist prompt-end compaction indices from the forward context.
This is the runner-side half of chunked-prefill scheme 3:
flash_attn -> forward_context payload -> request state stash ->
(next step) one-shot KV compaction.
"""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
scheduler_config = getattr(runner, "scheduler_config", None)
if scheduler_config is None or not getattr(scheduler_config,
"chunked_prefill_enabled",
False):
return
forward_context = get_forward_context()
payload = get_kv_compression_prompt_payload(forward_context)
if payload is None:
return
req_indices = payload.get("req_indices")
idx_sorted = payload.get("idx_sorted")
keep_len = payload.get("keep_len")
prompt_lens = payload.get("prompt_lens")
if (req_indices is None or idx_sorted is None or keep_len is None
or prompt_lens is None):
return
input_batch = getattr(runner, "input_batch", None)
if input_batch is None:
return
req_ids = getattr(input_batch, "req_ids", None)
if req_ids is None:
return
requests = getattr(runner, "requests", None)
if requests is None:
return
req_indices_cpu = req_indices.to(device="cpu", dtype=torch.int64).tolist()
keep_cpu = keep_len.to(device="cpu", dtype=torch.int64).tolist()
prompt_cpu = prompt_lens.to(device="cpu", dtype=torch.int64).tolist()
for i, b in enumerate(req_indices_cpu):
if b < 0 or b >= len(req_ids):
continue
req_id = req_ids[b]
if req_id is None:
continue
rs = requests.get(req_id)
if rs is None:
continue
rs.kv_compression_prompt_idx_sorted = idx_sorted[i]
rs.kv_compression_prompt_keep_len = int(keep_cpu[i])
rs.kv_compression_prompt_prompt_len = int(prompt_cpu[i])
def maybe_apply_kv_compression_prompt_compaction(*, runner: Any) -> None:
"""Apply one-shot prompt KV compaction before the first decode step."""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
scheduler_config = getattr(runner, "scheduler_config", None)
if scheduler_config is None or not getattr(scheduler_config,
"chunked_prefill_enabled",
False):
return
input_batch = getattr(runner, "input_batch", None)
if input_batch is None:
return
requests = getattr(runner, "requests", None)
if requests is None:
return
pending_req_ids: list[str] = []
for req_id in input_batch.req_ids:
if req_id is None:
continue
rs = requests.get(req_id)
if rs is None:
continue
if rs.kv_compression_prompt_idx_sorted is None:
continue
# Only apply once the prompt is fully ingested (decode stage).
if rs.num_computed_tokens < rs.num_prompt_tokens:
continue
pending_req_ids.append(req_id)
if not pending_req_ids:
return
device = runner.device
pending_states: list[tuple[str, torch.Tensor, int]] = []
for req_id in pending_req_ids:
rs = requests[req_id]
keep = rs.kv_compression_prompt_keep_len
idx = rs.kv_compression_prompt_idx_sorted
if keep is None or idx is None:
continue
keep_i = int(keep)
if keep_i <= 0:
# No prompt tokens kept; clear and skip.
rs.kv_compression_prompt_idx_sorted = None
rs.kv_compression_prompt_keep_len = None
rs.kv_compression_prompt_prompt_len = None
continue
pending_states.append((req_id, idx, keep_i))
if not pending_states:
return
B = len(pending_states)
keep_list = [k for _, _, k in pending_states]
K_max = max(keep_list)
idx_batch = torch.zeros((B, K_max), device=device, dtype=torch.int32)
for i, (_, row, k) in enumerate(pending_states):
idx_batch[i, :k] = row[:k].to(device=device, dtype=torch.int32)
keep_tensor = torch.tensor(keep_list, device=device, dtype=torch.int32)
from vllm.v1.kv_compression.kv_cache_triton import (
front_compact_inplace_fa_triton, make_fa_cache_view)
kv_cache_config = getattr(runner, "kv_cache_config", None)
if kv_cache_config is None:
return
# Apply compaction to every attention layer's KV cache in-place.
for group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
max_blocks = 0
for req_id, _, _ in pending_states:
rs = requests[req_id]
if group_id >= len(rs.block_ids):
continue
max_blocks = max(max_blocks, len(rs.block_ids[group_id]))
if max_blocks == 0:
continue
block_table_cpu = torch.zeros((B, max_blocks),
dtype=torch.int32,
device="cpu")
for i, (req_id, _, _) in enumerate(pending_states):
rs = requests[req_id]
if group_id >= len(rs.block_ids):
continue
ids = rs.block_ids[group_id]
if ids:
block_table_cpu[i, :len(ids)] = torch.tensor(ids,
dtype=torch.int32,
device="cpu")
block_table = block_table_cpu.to(device=device, non_blocking=True)
kv_caches = getattr(runner, "kv_caches", None)
if kv_caches is None:
continue
for layer_name in kv_cache_group_spec.layer_names:
layer_index = runner._extract_layer_index(layer_name)
if layer_index >= len(kv_caches):
continue
kv_cache = kv_caches[layer_index]
if not current_platform.is_rocm():
if not isinstance(kv_cache, torch.Tensor):
continue
key_cache, value_cache = kv_cache.unbind(0)
else:
if (not isinstance(kv_cache, (tuple, list))
or len(kv_cache) != 2):
continue
key_cache, value_cache = kv_cache
k_view, v_view = make_fa_cache_view(key_cache=key_cache,
value_cache=value_cache)
front_compact_inplace_fa_triton(
k_view,
v_view,
block_table,
idx_batch,
keep_tensor,
)
# Clear pending state after successful compaction.
for req_id, _, _ in pending_states:
rs = requests.get(req_id)
if rs is None:
continue
rs.kv_compression_prompt_idx_sorted = None
rs.kv_compression_prompt_keep_len = None
rs.kv_compression_prompt_prompt_len = None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any, Optional
import numpy as np
import vllm.envs as envs
from vllm.v1.kv_compression.runner_prepare import prepare_kv_compression_for_step
def maybe_prepare_kv_compression_for_runner_step(
*,
runner: Any,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray, # [B] int32
cu_num_tokens: np.ndarray, # [B] int64/int32
req_indices: np.ndarray, # [T] int64
arange: np.ndarray, # [T] int64
) -> Optional[np.ndarray]:
"""Prepare per-step KV compression metadata on CPU.
Returns the per-token KV positions (`kv_positions_np`) or None if KV
compression is disabled.
"""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
runner.kv_compression_needs_compaction = False
return None
kv_positions_np = runner.kv_positions_np[:total_num_scheduled_tokens]
must_keep_np = runner.kv_compression_must_keep_np[:total_num_scheduled_tokens]
topk_budget_np = runner.kv_compression_topk_budget_np[:num_reqs]
prompt_end_np = runner.kv_compression_prompt_end_np[:num_reqs]
prompt_lens_np = runner.kv_compression_prompt_lens_np[:num_reqs]
topk_keep_np = runner.kv_compression_prompt_topk_keep_np[:num_reqs]
needs_compaction, prompt_topk_keep_max = prepare_kv_compression_for_step(
num_reqs=num_reqs,
total_num_scheduled_tokens=total_num_scheduled_tokens,
num_scheduled_tokens=num_scheduled_tokens,
cu_num_tokens=cu_num_tokens,
req_indices=req_indices,
arange=arange,
num_computed_tokens_cpu=runner.input_batch.num_computed_tokens_cpu,
num_prompt_tokens=runner.input_batch.num_prompt_tokens,
num_kv_tokens_cpu=runner.input_batch.num_kv_tokens_cpu,
kv_positions_np=kv_positions_np,
must_keep_np=must_keep_np,
topk_budget_np=topk_budget_np,
prompt_end_np=prompt_end_np,
prompt_lens_np=prompt_lens_np,
prompt_topk_keep_np=topk_keep_np,
chunked_prefill_enabled=runner.scheduler_config.chunked_prefill_enabled,
)
runner.kv_compression_needs_compaction = bool(needs_compaction)
if prompt_topk_keep_max is not None:
runner.kv_compression_prompt_topk_keep_max = int(prompt_topk_keep_max)
return kv_positions_np
def maybe_copy_kv_compression_step_tensors_to_gpu(
*,
runner: Any,
num_reqs: int,
total_num_scheduled_tokens: int,
non_blocking: bool = True,
) -> None:
"""Stage per-step KV compression tensors to GPU if needed."""
if not envs.VLLM_ENABLE_KV_COMPRESSION:
return
if runner.scheduler_config.chunked_prefill_enabled:
runner.kv_compression_prompt_end[:num_reqs].copy_(
runner.kv_compression_prompt_end_cpu[:num_reqs],
non_blocking=non_blocking,
)
runner.kv_compression_prompt_lens[:num_reqs].copy_(
runner.kv_compression_prompt_lens_cpu[:num_reqs],
non_blocking=non_blocking,
)
runner.kv_compression_prompt_topk_keep[:num_reqs].copy_(
runner.kv_compression_prompt_topk_keep_cpu[:num_reqs],
non_blocking=non_blocking,
)
return
if runner.kv_compression_needs_compaction:
runner.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
runner.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=non_blocking,
)
runner.kv_compression_topk_budget[:num_reqs].copy_(
runner.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=non_blocking,
)
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import numpy as np
from vllm.v1.worker.block_table import BlockTable
def fill_block_table_slot_mapping_np(
*,
block_table: BlockTable,
req_indices: np.ndarray, # [T] int64
slot_positions_np: np.ndarray, # [T] int64
total_num_scheduled_tokens: int,
block_size: int,
) -> None:
"""Fill `block_table.slot_mapping_np` for a packed batch of tokens."""
block_table_indices = (req_indices * block_table.max_num_blocks_per_req +
slot_positions_np // int(block_size))
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = slot_positions_np % int(block_size)
np.add(
block_numbers * int(block_size),
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment