Commit b0911b24 authored by laibao's avatar laibao
Browse files

feat(kvpress): 增加调度侧 KV 长度记账

parent 87b788bd
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsReader,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
......@@ -50,6 +51,12 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import SlidingWindowSpec
from vllm.v1.kv_compression.scheduler_accounting import (
maybe_init_num_kv_tokens_on_running_transition,
update_num_kv_tokens_after_schedule,
)
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
......@@ -204,6 +211,7 @@ class Scheduler(SchedulerInterface):
)
speculative_config = vllm_config.speculative_config
self.speculative_config = speculative_config
self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config:
......@@ -218,6 +226,70 @@ class Scheduler(SchedulerInterface):
self.full_cuda_graph = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
self.use_mla = vllm_config.model_config.use_mla
# KV compression is a cross-component feature: Scheduler handles gate +
# accounting; Worker generates slot_mapping/metadata; attention backend
# performs scoring/Top-K selection and KV rewrite/compaction.
#
# Gate early to avoid enabling KV-compression accounting/slot mapping
# on unsupported platforms or incompatible feature combinations.
self.kv_compression_enabled = (
envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_cuda_alike()
)
if envs.VLLM_ENABLE_KV_COMPRESSION and not self.kv_compression_enabled:
logger.warning_once(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform."
)
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError("VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups
):
raise ValueError(
"KV compression is incompatible with sliding window attention."
)
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression."
)
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA graph mode."
)
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with speculative decoding."
)
if self.dcp_world_size > 1 or self.pcp_world_size > 1:
raise ValueError(
"KV compression is currently incompatible with context parallelism "
"(dcp_world_size > 1 or pcp_world_size > 1)."
)
backend = self.vllm_config.attention_config.backend
if backend is not None and backend != AttentionBackendEnum.FLASH_ATTN:
raise ValueError(
"KV compression currently requires the FLASH_ATTN backend. "
f"Got attention_config.backend={backend}."
)
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError("VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError("VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0."
)
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0."
)
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError("VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
......@@ -775,6 +847,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -1163,6 +1240,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -1499,6 +1581,7 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0
request.num_kv_tokens = 0
request.spec_token_ids.clear()
request.num_preemptions += 1
if self.log_stats:
......@@ -1520,7 +1603,15 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token
update_num_kv_tokens_after_schedule(
request=request,
start_pos=start_pos,
num_scheduled_token=num_scheduled_token,
chunked_prefill_enabled=self.scheduler_config.enable_chunked_prefill,
kv_compression_enabled=self.kv_compression_enabled,
)
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative
......@@ -1593,6 +1684,7 @@ class Scheduler(SchedulerInterface):
new_block_ids: list[tuple[list[int], ...] | None] = []
all_token_ids: dict[str, list[int]] = {}
num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
num_output_tokens: list[int] = []
resumed_req_ids = set()
......@@ -1623,6 +1715,7 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
)
num_computed_tokens.append(req.num_computed_tokens)
num_kv_tokens.append(req.num_kv_tokens)
num_output_tokens.append(
req.num_output_tokens + req.num_output_placeholders
)
......@@ -1634,6 +1727,7 @@ class Scheduler(SchedulerInterface):
all_token_ids=all_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
num_output_tokens=num_output_tokens,
)
......@@ -1892,6 +1986,8 @@ class Scheduler(SchedulerInterface):
# tokens.
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
if request.num_kv_tokens > 0:
request.num_kv_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
......@@ -2519,6 +2615,11 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any
import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range)
def maybe_init_num_kv_tokens_on_running_transition(
    *,
    request: Any,
    num_computed_tokens: int,
    kv_compression_enabled: bool,
) -> None:
    """Seed ``request.num_kv_tokens`` when a request transitions to RUNNING.

    Args:
        request: Request object whose ``num_kv_tokens`` attribute is written.
            A missing attribute is treated as 0 when compression is enabled.
        num_computed_tokens: Logical number of tokens already computed for
            the request at the moment of the transition.
        kv_compression_enabled: Scheduler-side gate for KV compression.

    Behavior:
        * Compression disabled: the stored KV length always mirrors the
          logical length, so ``num_kv_tokens`` is overwritten
          unconditionally with ``num_computed_tokens``.
        * Compression enabled: the counter is only seeded when it is still 0
          and the request already has computed tokens. Per the module
          author's notes this covers paths such as a KV connector or a cache
          hit, where a request enters RUNNING with tokens already computed;
          leaving the counter at 0 there would make later KV write offsets
          (derived from ``num_kv_tokens``) start from 0 and misalign
          slot_mapping / KV cache writes.
    """
    if kv_compression_enabled:
        # A non-zero counter means compression accounting is already in
        # progress for this request and must not be clobbered.
        untracked = getattr(request, "num_kv_tokens", 0) == 0
        has_precomputed = untracked and num_computed_tokens > 0
        if not has_precomputed:
            return

    request.num_kv_tokens = num_computed_tokens
def update_num_kv_tokens_after_schedule(
    *,
    request: Any,
    start_pos: int,
    num_scheduled_token: int,
    chunked_prefill_enabled: bool,
    kv_compression_enabled: bool,
) -> None:
    """Advance ``request.num_kv_tokens`` after one scheduling step.

    This is pure scheduler-side bookkeeping: it performs no scoring, Top-K
    selection or KV rewriting. It only keeps a running count of how many KV
    entries the request is expected to hold. Per the module author's notes,
    that count feeds KV block allocation, worker-side slot_mapping / write
    offsets and the ``seq_lens`` in attention metadata, so a wrong value can
    lead to misaligned or out-of-bounds KV accesses.

    Terminology (for a single request):
        * ``num_computed_tokens``: logical progress (token position).
        * ``num_kv_tokens``: number of tokens actually retained in the KV
          cache. With compression enabled, prompt KV may be reduced, so this
          can be smaller than ``num_computed_tokens``; decode tokens are
          always counted in full.

    Args:
        request: Request object; ``num_kv_tokens`` is updated in place and
            ``num_prompt_tokens`` is read (defaulting to 0 when absent).
        start_pos: Logical position before this step. The newly scheduled
            logical range is ``[start_pos, start_pos + num_scheduled_token)``.
        num_scheduled_token: Tokens scheduled for the request in this step;
            may cover prompt tokens, decode tokens or both. Values ``<= 0``
            make the call a no-op.
        chunked_prefill_enabled: Whether chunked prefill is enabled. In that
            mode the prompt is counted in full while it is being ingested
            and the count is reset to the compressed size in the step that
            completes the prompt (a one-shot adjustment at prompt end).
        kv_compression_enabled: Scheduler-side gate for KV compression.
    """
    if num_scheduled_token <= 0:
        return

    if not kv_compression_enabled:
        # Without compression the KV length tracks the logical length 1:1.
        request.num_kv_tokens += num_scheduled_token
        return

    # Compression knobs are read up front, in a fixed order, before any
    # positional arithmetic.
    ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
    budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
    prefix_len = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
    suffix_len = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
    keep_last_token = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN

    end_pos = start_pos + int(num_scheduled_token)
    prompt_len = int(getattr(request, "num_prompt_tokens", 0))

    # Keyword groups shared by the budget helpers.
    protection = {
        "prompt_len": prompt_len,
        "protected_prefix": prefix_len,
        "protected_suffix": suffix_len,
        "keep_last_token": keep_last_token,
    }
    quota = {"prompt_ratio": ratio, "prompt_budget": budget}

    # Tokens of this step that lie at or beyond the prompt end are decode
    # tokens and are always retained in full.
    num_decode = max(0, end_pos - max(start_pos, prompt_len))

    if chunked_prefill_enabled:
        # While the prompt is still being ingested (or once we are purely
        # decoding) the count simply accumulates. Only the step that reaches
        # the end of the prompt triggers the adjustment below.
        finishes_prompt = start_pos < prompt_len <= end_pos
        if finishes_prompt:
            # Deliberately an assignment, not an increment: the count jumps
            # from "full prompt prefix" to "compressed prompt" here. Per the
            # module author's notes, the actual in-place compaction happens
            # on the worker side before decode.
            request.num_kv_tokens = (
                compute_prompt_keep_len(**protection, **quota) + num_decode
            )
        else:
            request.num_kv_tokens += num_scheduled_token
        return

    # Non-chunked prefill: selection is accounted per step over the logical
    # range [start_pos, end_pos).
    span = {"start_pos": start_pos, "end_pos": end_pos}

    # Prompt tokens in this range that are always kept (protected prefix,
    # protected suffix and, optionally, the last prompt token).
    must_keep = count_prompt_must_keep_in_range(**protection, **span)

    # Additional prompt tokens retained via the Top-K budget share that
    # `compute_topk_budget_step` attributes to this step.
    topk_keep = compute_topk_budget_step(**protection, **span, **quota)

    request.num_kv_tokens += num_decode + must_keep + topk_keep
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment