Commit b0911b24 authored by laibao's avatar laibao
Browse files

feat(kvpress): 增加调度侧 KV 长度记账

parent 87b788bd
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsReader, RoutedExpertsReader,
) )
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.core.encoder_cache_manager import ( from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager, EncoderCacheManager,
EncoderDecoderCacheManager, EncoderDecoderCacheManager,
...@@ -50,6 +51,12 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu ...@@ -50,6 +51,12 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import SlidingWindowSpec
from vllm.v1.kv_compression.scheduler_accounting import (
maybe_init_num_kv_tokens_on_running_transition,
update_num_kv_tokens_after_schedule,
)
from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
...@@ -204,6 +211,7 @@ class Scheduler(SchedulerInterface): ...@@ -204,6 +211,7 @@ class Scheduler(SchedulerInterface):
) )
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
self.speculative_config = speculative_config
self.use_eagle = False self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0 self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config: if speculative_config:
...@@ -218,6 +226,70 @@ class Scheduler(SchedulerInterface): ...@@ -218,6 +226,70 @@ class Scheduler(SchedulerInterface):
self.full_cuda_graph = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL self.full_cuda_graph = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
self.use_mla = vllm_config.model_config.use_mla self.use_mla = vllm_config.model_config.use_mla
# KV compression is a cross-component feature: Scheduler handles gate +
# accounting; Worker generates slot_mapping/metadata; attention backend
# performs scoring/Top-K selection and KV rewrite/compaction.
#
# Gate early to avoid enabling KV-compression accounting/slot mapping
# on unsupported platforms or incompatible feature combinations.
self.kv_compression_enabled = (
envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_cuda_alike()
)
if envs.VLLM_ENABLE_KV_COMPRESSION and not self.kv_compression_enabled:
logger.warning_once(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform."
)
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError("VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups
):
raise ValueError(
"KV compression is incompatible with sliding window attention."
)
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression."
)
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA graph mode."
)
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with speculative decoding."
)
if self.dcp_world_size > 1 or self.pcp_world_size > 1:
raise ValueError(
"KV compression is currently incompatible with context parallelism "
"(dcp_world_size > 1 or pcp_world_size > 1)."
)
backend = self.vllm_config.attention_config.backend
if backend is not None and backend != AttentionBackendEnum.FLASH_ATTN:
raise ValueError(
"KV compression currently requires the FLASH_ATTN backend. "
f"Got attention_config.backend={backend}."
)
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError("VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError("VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0."
)
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0."
)
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError("VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
...@@ -775,6 +847,11 @@ class Scheduler(SchedulerInterface): ...@@ -775,6 +847,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -1163,6 +1240,11 @@ class Scheduler(SchedulerInterface): ...@@ -1163,6 +1240,11 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Count the number of prefix cached tokens. # Count the number of prefix cached tokens.
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
...@@ -1499,6 +1581,7 @@ class Scheduler(SchedulerInterface): ...@@ -1499,6 +1581,7 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0 request.num_computed_tokens = 0
request.num_kv_tokens = 0
request.spec_token_ids.clear() request.spec_token_ids.clear()
request.num_preemptions += 1 request.num_preemptions += 1
if self.log_stats: if self.log_stats:
...@@ -1520,7 +1603,15 @@ class Scheduler(SchedulerInterface): ...@@ -1520,7 +1603,15 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items(): for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id] request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token request.num_computed_tokens += num_scheduled_token
update_num_kv_tokens_after_schedule(
request=request,
start_pos=start_pos,
num_scheduled_token=num_scheduled_token,
chunked_prefill_enabled=self.scheduler_config.enable_chunked_prefill,
kv_compression_enabled=self.kv_compression_enabled,
)
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative # may be updated again in _update_from_output for speculative
...@@ -1593,6 +1684,7 @@ class Scheduler(SchedulerInterface): ...@@ -1593,6 +1684,7 @@ class Scheduler(SchedulerInterface):
new_block_ids: list[tuple[list[int], ...] | None] = [] new_block_ids: list[tuple[list[int], ...] | None] = []
all_token_ids: dict[str, list[int]] = {} all_token_ids: dict[str, list[int]] = {}
num_computed_tokens: list[int] = [] num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
num_output_tokens: list[int] = [] num_output_tokens: list[int] = []
resumed_req_ids = set() resumed_req_ids = set()
...@@ -1623,6 +1715,7 @@ class Scheduler(SchedulerInterface): ...@@ -1623,6 +1715,7 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks[req_id].get_block_ids(allow_none=True) req_to_new_blocks[req_id].get_block_ids(allow_none=True)
) )
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
num_kv_tokens.append(req.num_kv_tokens)
num_output_tokens.append( num_output_tokens.append(
req.num_output_tokens + req.num_output_placeholders req.num_output_tokens + req.num_output_placeholders
) )
...@@ -1634,6 +1727,7 @@ class Scheduler(SchedulerInterface): ...@@ -1634,6 +1727,7 @@ class Scheduler(SchedulerInterface):
all_token_ids=all_token_ids, all_token_ids=all_token_ids,
new_block_ids=new_block_ids, new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens, num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
num_output_tokens=num_output_tokens, num_output_tokens=num_output_tokens,
) )
...@@ -1892,6 +1986,8 @@ class Scheduler(SchedulerInterface): ...@@ -1892,6 +1986,8 @@ class Scheduler(SchedulerInterface):
# tokens. # tokens.
if request.num_computed_tokens > 0: if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected request.num_computed_tokens -= num_rejected
if request.num_kv_tokens > 0:
request.num_kv_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes # If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted. # the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0: if request.num_output_placeholders > 0:
...@@ -2519,6 +2615,11 @@ class Scheduler(SchedulerInterface): ...@@ -2519,6 +2615,11 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling. # Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition(
request=request,
num_computed_tokens=num_computed_tokens,
kv_compression_enabled=self.kv_compression_enabled,
)
# Return that we are ready. # Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Any
import vllm.envs as envs
from vllm.v1.kv_compression.budget import (compute_prompt_keep_len,
compute_topk_budget_step,
count_prompt_must_keep_in_range)
def maybe_init_num_kv_tokens_on_running_transition(
*,
request: Any,
num_computed_tokens: int,
kv_compression_enabled: bool,
) -> None:
"""在 request 切换为 RUNNING 时,必要时初始化 `request.num_kv_tokens`。
- 未开启 KV compression:KV 实际长度始终等于逻辑长度,直接令
`num_kv_tokens = num_computed_tokens` 即可。
- 开启 KV compression:大多数请求从 0 token 开始(`num_computed_tokens == 0`),
不需要额外初始化;但某些路径(例如 KV connector / cache hit)可能让一个请求在
进入 RUNNING 时已经“预先拥有”一段已计算的 token。如果不把 `num_kv_tokens`
初始化到同样的值,后续 KV 写入偏移(基于 `num_kv_tokens`)会从 0 开始,导致
slot_mapping/KV cache 写入错位。
"""
if not kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
return
if getattr(request, "num_kv_tokens", 0) == 0 and num_computed_tokens > 0:
request.num_kv_tokens = num_computed_tokens
def update_num_kv_tokens_after_schedule(
*,
request: Any,
start_pos: int,
num_scheduled_token: int,
chunked_prefill_enabled: bool,
kv_compression_enabled: bool,
) -> None:
"""在一次调度(一个 step)之后推进 `request.num_kv_tokens`。
这是 KV compression 的“调度侧记账”函数(不做打分/TopK/重写 KV),目的仅是
让 Scheduler 维护出“KV cache 实际长度”,以便后续:
- KV block 分配(allocate_slots)
- Worker 侧 slot_mapping / KV 写入偏移
- attention metadata 里的 `seq_lens`
都基于正确的 KV 长度工作。
关键概念(针对单个 request):
- `num_computed_tokens`:逻辑进度(token 位置 / RoPE index)。
- `num_kv_tokens`:KV cache 中“实际保留/存储”的 token 数。
开启 KV compression 后,prompt KV 可能被压缩(只保留一部分),因此
`num_kv_tokens` 可能小于 `num_computed_tokens`;但 decode token 始终全保留。
参数含义:
- `start_pos`:本 step 开始前的逻辑位置;本次新调度的逻辑区间为
`[start_pos, end_pos)`。
- `num_scheduled_token`:该 request 在本 step 被调度的 token 数(可能同时包含
prompt token 和 decode token)。
- `chunked_prefill_enabled`:是否启用 chunked prefill。
- `kv_compression_enabled`:调度器 gate 后的“是否启用 KV compression”。
注意:
- 如果这里记账错了,worker 可能用错误的 KV 基址/长度生成 slot_mapping,
从而导致 KV cache 读写错位甚至越界。
- chunked prefill 模式下,为避免“下一段 prefill 看不到完整历史”导致质量崩溃,
prompt 的 KV compaction 会延后到 prompt 结束后一次性执行(prompt-end one-shot)。
因此 `num_kv_tokens` 的更新策略与非 chunked 模式不同。
"""
if num_scheduled_token <= 0:
return
if not kv_compression_enabled:
# 未开启压缩:KV cache 长度与逻辑长度始终一致,直接累加即可。
request.num_kv_tokens += num_scheduled_token
return
# 开启 KV compression 后:prompt token 只保留子集;decode token 始终全保留。
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = start_pos + int(num_scheduled_token)
prompt_end = int(getattr(request, "num_prompt_tokens", 0))
# Chunked prefill:prefill 过程中不要“边 ingest 边压缩 prompt KV”。
# 否则下一段 chunk prefill 会注意力看不到完整历史(语义变化/质量崩溃)。
# 正确策略是:prefill 阶段暂时全量保留 prompt KV;等 prompt ingest 完成后,
# 在第一次 decode 前做一次性 prompt compaction(prompt-end one-shot)。
if chunked_prefill_enabled:
if start_pos >= prompt_end:
# 纯 decode 段:decode token 始终全保留,KV 长度直接累加。
request.num_kv_tokens += num_scheduled_token
return
if end_pos < prompt_end:
# prompt 还没 ingest 完:暂时先全保留(不做 mid-prefill 压缩),KV 长度累加。
request.num_kv_tokens += num_scheduled_token
return
# 重要:这里是“重置”而不是“累加”。
# 因为 prompt 结束后 KV 长度会发生不连续跳变:
# - prompt ingest 过程中:KV cache 中存的是“完整 prompt 前缀”
# - prompt ingest 完成后:KV cache 中应变为“压缩后的 prompt”
# 实际的 in-place compaction 在 worker 侧 decode 前执行;这里先把记账值更新到位。
kept_prompt_total = compute_prompt_keep_len(
prompt_len=prompt_end,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
# 如果本 step 跨过 prompt_end,prompt_end 之后的 token 属于 decode,仍需全保留。
kept_decode = max(0, end_pos - max(start_pos, prompt_end))
request.num_kv_tokens = kept_prompt_total + kept_decode
return
# 非 chunked prefill(scheme 1/2):每个 step 内做 token-shared 的选择。
# - decode token:始终全保留;
# - prompt token:只保留 must-keep(protected prefix/suffix/可选最后token)
# + 本 step Top-K 选中的部分。
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
# 本 step 的逻辑区间内,prompt token 里“必须保留”的部分:
# protected_prefix / protected_suffix /(可选)最后一个 prompt token。
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
# 本 step 通过 Top-K 策略“额外保留”的 prompt token 数。
# 预算定义在 prompt 的“非保护区”上,并由 `compute_topk_budget_step` 按 step 分摊。
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
# 本 step 结束后:KV cache 实际长度按“保留的 KV 条目数”推进。
request.num_kv_tokens += (kept_decode + kept_prompt_must_keep +
kept_prompt_topk)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment