Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0911b24
Commit
b0911b24
authored
Feb 24, 2026
by
laibao
Browse files
feat(kvpress): 增加调度侧 KV 长度记账
parent
87b788bd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
259 additions
and
0 deletions
+259
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+101
-0
vllm/v1/kv_compression/scheduler_accounting.py
vllm/v1/kv_compression/scheduler_accounting.py
+158
-0
No files found.
vllm/v1/core/sched/scheduler.py
View file @
b0911b24
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsReader
,
RoutedExpertsReader
,
)
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.platforms
import
current_platform
from
vllm.v1.core.encoder_cache_manager
import
(
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
EncoderCacheManager
,
EncoderDecoderCacheManager
,
EncoderDecoderCacheManager
,
...
@@ -50,6 +51,12 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
...
@@ -50,6 +51,12 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
MambaSpec
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.kv_cache_interface
import
SlidingWindowSpec
from
vllm.v1.kv_compression.scheduler_accounting
import
(
maybe_init_num_kv_tokens_on_running_transition
,
update_num_kv_tokens_after_schedule
,
)
from
vllm.v1.metrics.perf
import
ModelMetrics
,
PerfStats
from
vllm.v1.metrics.perf
import
ModelMetrics
,
PerfStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
,
SchedulerStats
from
vllm.v1.outputs
import
DraftTokenIds
,
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
DraftTokenIds
,
KVConnectorOutput
,
ModelRunnerOutput
...
@@ -204,6 +211,7 @@ class Scheduler(SchedulerInterface):
...
@@ -204,6 +211,7 @@ class Scheduler(SchedulerInterface):
)
)
speculative_config
=
vllm_config
.
speculative_config
speculative_config
=
vllm_config
.
speculative_config
self
.
speculative_config
=
speculative_config
self
.
use_eagle
=
False
self
.
use_eagle
=
False
self
.
num_spec_tokens
=
self
.
num_lookahead_tokens
=
0
self
.
num_spec_tokens
=
self
.
num_lookahead_tokens
=
0
if
speculative_config
:
if
speculative_config
:
...
@@ -218,6 +226,70 @@ class Scheduler(SchedulerInterface):
...
@@ -218,6 +226,70 @@ class Scheduler(SchedulerInterface):
self
.
full_cuda_graph
=
self
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
self
.
full_cuda_graph
=
self
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
FULL
self
.
use_mla
=
vllm_config
.
model_config
.
use_mla
self
.
use_mla
=
vllm_config
.
model_config
.
use_mla
# KV compression is a cross-component feature: Scheduler handles gate +
# accounting; Worker generates slot_mapping/metadata; attention backend
# performs scoring/Top-K selection and KV rewrite/compaction.
#
# Gate early to avoid enabling KV-compression accounting/slot mapping
# on unsupported platforms or incompatible feature combinations.
self
.
kv_compression_enabled
=
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
current_platform
.
is_cuda_alike
()
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
not
self
.
kv_compression_enabled
:
logger
.
warning_once
(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform."
)
if
self
.
kv_compression_enabled
:
if
envs
.
VLLM_KV_COMPRESSION_POLICY
!=
"topk"
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'."
)
if
any
(
isinstance
(
group
.
kv_cache_spec
,
SlidingWindowSpec
)
for
group
in
kv_cache_config
.
kv_cache_groups
):
raise
ValueError
(
"KV compression is incompatible with sliding window attention."
)
if
self
.
cache_config
.
enable_prefix_caching
:
raise
ValueError
(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression."
)
if
self
.
full_cuda_graph
:
raise
ValueError
(
"KV compression is currently incompatible with full CUDA graph mode."
)
if
self
.
speculative_config
is
not
None
:
raise
ValueError
(
"KV compression is currently incompatible with speculative decoding."
)
if
self
.
dcp_world_size
>
1
or
self
.
pcp_world_size
>
1
:
raise
ValueError
(
"KV compression is currently incompatible with context parallelism "
"(dcp_world_size > 1 or pcp_world_size > 1)."
)
backend
=
self
.
vllm_config
.
attention_config
.
backend
if
backend
is
not
None
and
backend
!=
AttentionBackendEnum
.
FLASH_ATTN
:
raise
ValueError
(
"KV compression currently requires the FLASH_ATTN backend. "
f
"Got attention_config.backend=
{
backend
}
."
)
if
envs
.
VLLM_KV_COMPRESSION_PROMPT_BUDGET
<
-
1
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1."
)
if
not
(
0.0
<=
envs
.
VLLM_KV_COMPRESSION_PROMPT_RATIO
<=
1.0
):
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1]."
)
if
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
<
0
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0."
)
if
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
<
0
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0."
)
if
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
<
1
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1."
)
# Create the KV cache manager.
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
self
.
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
...
@@ -775,6 +847,11 @@ class Scheduler(SchedulerInterface):
...
@@ -775,6 +847,11 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition
(
request
=
request
,
num_computed_tokens
=
num_computed_tokens
,
kv_compression_enabled
=
self
.
kv_compression_enabled
,
)
# Count the number of prefix cached tokens.
# Count the number of prefix cached tokens.
if
request
.
num_cached_tokens
<
0
:
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
request
.
num_cached_tokens
=
num_computed_tokens
...
@@ -1163,6 +1240,11 @@ class Scheduler(SchedulerInterface):
...
@@ -1163,6 +1240,11 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition
(
request
=
request
,
num_computed_tokens
=
num_computed_tokens
,
kv_compression_enabled
=
self
.
kv_compression_enabled
,
)
# Count the number of prefix cached tokens.
# Count the number of prefix cached tokens.
if
request
.
num_cached_tokens
<
0
:
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
request
.
num_cached_tokens
=
num_computed_tokens
...
@@ -1499,6 +1581,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1499,6 +1581,7 @@ class Scheduler(SchedulerInterface):
self
.
encoder_cache_manager
.
free
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
request
.
status
=
RequestStatus
.
PREEMPTED
request
.
status
=
RequestStatus
.
PREEMPTED
request
.
num_computed_tokens
=
0
request
.
num_computed_tokens
=
0
request
.
num_kv_tokens
=
0
request
.
spec_token_ids
.
clear
()
request
.
spec_token_ids
.
clear
()
request
.
num_preemptions
+=
1
request
.
num_preemptions
+=
1
if
self
.
log_stats
:
if
self
.
log_stats
:
...
@@ -1520,7 +1603,15 @@ class Scheduler(SchedulerInterface):
...
@@ -1520,7 +1603,15 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
for
req_id
,
num_scheduled_token
in
num_scheduled_tokens
.
items
():
for
req_id
,
num_scheduled_token
in
num_scheduled_tokens
.
items
():
request
=
self
.
requests
[
req_id
]
request
=
self
.
requests
[
req_id
]
start_pos
=
request
.
num_computed_tokens
request
.
num_computed_tokens
+=
num_scheduled_token
request
.
num_computed_tokens
+=
num_scheduled_token
update_num_kv_tokens_after_schedule
(
request
=
request
,
start_pos
=
start_pos
,
num_scheduled_token
=
num_scheduled_token
,
chunked_prefill_enabled
=
self
.
scheduler_config
.
enable_chunked_prefill
,
kv_compression_enabled
=
self
.
kv_compression_enabled
,
)
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative
# may be updated again in _update_from_output for speculative
...
@@ -1593,6 +1684,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1593,6 +1684,7 @@ class Scheduler(SchedulerInterface):
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
=
[]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]
|
None
]
=
[]
all_token_ids
:
dict
[
str
,
list
[
int
]]
=
{}
all_token_ids
:
dict
[
str
,
list
[
int
]]
=
{}
num_computed_tokens
:
list
[
int
]
=
[]
num_computed_tokens
:
list
[
int
]
=
[]
num_kv_tokens
:
list
[
int
]
=
[]
num_output_tokens
:
list
[
int
]
=
[]
num_output_tokens
:
list
[
int
]
=
[]
resumed_req_ids
=
set
()
resumed_req_ids
=
set
()
...
@@ -1623,6 +1715,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1623,6 +1715,7 @@ class Scheduler(SchedulerInterface):
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
)
)
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
num_kv_tokens
.
append
(
req
.
num_kv_tokens
)
num_output_tokens
.
append
(
num_output_tokens
.
append
(
req
.
num_output_tokens
+
req
.
num_output_placeholders
req
.
num_output_tokens
+
req
.
num_output_placeholders
)
)
...
@@ -1634,6 +1727,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1634,6 +1727,7 @@ class Scheduler(SchedulerInterface):
all_token_ids
=
all_token_ids
,
all_token_ids
=
all_token_ids
,
new_block_ids
=
new_block_ids
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
num_computed_tokens
=
num_computed_tokens
,
num_kv_tokens
=
num_kv_tokens
,
num_output_tokens
=
num_output_tokens
,
num_output_tokens
=
num_output_tokens
,
)
)
...
@@ -1892,6 +1986,8 @@ class Scheduler(SchedulerInterface):
...
@@ -1892,6 +1986,8 @@ class Scheduler(SchedulerInterface):
# tokens.
# tokens.
if
request
.
num_computed_tokens
>
0
:
if
request
.
num_computed_tokens
>
0
:
request
.
num_computed_tokens
-=
num_rejected
request
.
num_computed_tokens
-=
num_rejected
if
request
.
num_kv_tokens
>
0
:
request
.
num_kv_tokens
-=
num_rejected
# If async scheduling, num_output_placeholders also includes
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
# the scheduled spec tokens count and so is similarly adjusted.
if
request
.
num_output_placeholders
>
0
:
if
request
.
num_output_placeholders
>
0
:
...
@@ -2519,6 +2615,11 @@ class Scheduler(SchedulerInterface):
...
@@ -2519,6 +2615,11 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
maybe_init_num_kv_tokens_on_running_transition
(
request
=
request
,
num_computed_tokens
=
num_computed_tokens
,
kv_compression_enabled
=
self
.
kv_compression_enabled
,
)
# Return that we are ready.
# Return that we are ready.
self
.
finished_recving_kv_req_ids
.
remove
(
request
.
request_id
)
self
.
finished_recving_kv_req_ids
.
remove
(
request
.
request_id
)
...
...
vllm/v1/kv_compression/scheduler_accounting.py
0 → 100644
View file @
b0911b24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.budget
import
(
compute_prompt_keep_len
,
compute_topk_budget_step
,
count_prompt_must_keep_in_range
)
def
maybe_init_num_kv_tokens_on_running_transition
(
*
,
request
:
Any
,
num_computed_tokens
:
int
,
kv_compression_enabled
:
bool
,
)
->
None
:
"""在 request 切换为 RUNNING 时,必要时初始化 `request.num_kv_tokens`。
- 未开启 KV compression:KV 实际长度始终等于逻辑长度,直接令
`num_kv_tokens = num_computed_tokens` 即可。
- 开启 KV compression:大多数请求从 0 token 开始(`num_computed_tokens == 0`),
不需要额外初始化;但某些路径(例如 KV connector / cache hit)可能让一个请求在
进入 RUNNING 时已经“预先拥有”一段已计算的 token。如果不把 `num_kv_tokens`
初始化到同样的值,后续 KV 写入偏移(基于 `num_kv_tokens`)会从 0 开始,导致
slot_mapping/KV cache 写入错位。
"""
if
not
kv_compression_enabled
:
request
.
num_kv_tokens
=
num_computed_tokens
return
if
getattr
(
request
,
"num_kv_tokens"
,
0
)
==
0
and
num_computed_tokens
>
0
:
request
.
num_kv_tokens
=
num_computed_tokens
def
update_num_kv_tokens_after_schedule
(
*
,
request
:
Any
,
start_pos
:
int
,
num_scheduled_token
:
int
,
chunked_prefill_enabled
:
bool
,
kv_compression_enabled
:
bool
,
)
->
None
:
"""在一次调度(一个 step)之后推进 `request.num_kv_tokens`。
这是 KV compression 的“调度侧记账”函数(不做打分/TopK/重写 KV),目的仅是
让 Scheduler 维护出“KV cache 实际长度”,以便后续:
- KV block 分配(allocate_slots)
- Worker 侧 slot_mapping / KV 写入偏移
- attention metadata 里的 `seq_lens`
都基于正确的 KV 长度工作。
关键概念(针对单个 request):
- `num_computed_tokens`:逻辑进度(token 位置 / RoPE index)。
- `num_kv_tokens`:KV cache 中“实际保留/存储”的 token 数。
开启 KV compression 后,prompt KV 可能被压缩(只保留一部分),因此
`num_kv_tokens` 可能小于 `num_computed_tokens`;但 decode token 始终全保留。
参数含义:
- `start_pos`:本 step 开始前的逻辑位置;本次新调度的逻辑区间为
`[start_pos, end_pos)`。
- `num_scheduled_token`:该 request 在本 step 被调度的 token 数(可能同时包含
prompt token 和 decode token)。
- `chunked_prefill_enabled`:是否启用 chunked prefill。
- `kv_compression_enabled`:调度器 gate 后的“是否启用 KV compression”。
注意:
- 如果这里记账错了,worker 可能用错误的 KV 基址/长度生成 slot_mapping,
从而导致 KV cache 读写错位甚至越界。
- chunked prefill 模式下,为避免“下一段 prefill 看不到完整历史”导致质量崩溃,
prompt 的 KV compaction 会延后到 prompt 结束后一次性执行(prompt-end one-shot)。
因此 `num_kv_tokens` 的更新策略与非 chunked 模式不同。
"""
if
num_scheduled_token
<=
0
:
return
if
not
kv_compression_enabled
:
# 未开启压缩:KV cache 长度与逻辑长度始终一致,直接累加即可。
request
.
num_kv_tokens
+=
num_scheduled_token
return
# 开启 KV compression 后:prompt token 只保留子集;decode token 始终全保留。
prompt_ratio
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last
=
envs
.
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos
=
start_pos
+
int
(
num_scheduled_token
)
prompt_end
=
int
(
getattr
(
request
,
"num_prompt_tokens"
,
0
))
# Chunked prefill:prefill 过程中不要“边 ingest 边压缩 prompt KV”。
# 否则下一段 chunk prefill 会注意力看不到完整历史(语义变化/质量崩溃)。
# 正确策略是:prefill 阶段暂时全量保留 prompt KV;等 prompt ingest 完成后,
# 在第一次 decode 前做一次性 prompt compaction(prompt-end one-shot)。
if
chunked_prefill_enabled
:
if
start_pos
>=
prompt_end
:
# 纯 decode 段:decode token 始终全保留,KV 长度直接累加。
request
.
num_kv_tokens
+=
num_scheduled_token
return
if
end_pos
<
prompt_end
:
# prompt 还没 ingest 完:暂时先全保留(不做 mid-prefill 压缩),KV 长度累加。
request
.
num_kv_tokens
+=
num_scheduled_token
return
# 重要:这里是“重置”而不是“累加”。
# 因为 prompt 结束后 KV 长度会发生不连续跳变:
# - prompt ingest 过程中:KV cache 中存的是“完整 prompt 前缀”
# - prompt ingest 完成后:KV cache 中应变为“压缩后的 prompt”
# 实际的 in-place compaction 在 worker 侧 decode 前执行;这里先把记账值更新到位。
kept_prompt_total
=
compute_prompt_keep_len
(
prompt_len
=
prompt_end
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
# 如果本 step 跨过 prompt_end,prompt_end 之后的 token 属于 decode,仍需全保留。
kept_decode
=
max
(
0
,
end_pos
-
max
(
start_pos
,
prompt_end
))
request
.
num_kv_tokens
=
kept_prompt_total
+
kept_decode
return
# 非 chunked prefill(scheme 1/2):每个 step 内做 token-shared 的选择。
# - decode token:始终全保留;
# - prompt token:只保留 must-keep(protected prefix/suffix/可选最后token)
# + 本 step Top-K 选中的部分。
decode_start
=
max
(
start_pos
,
prompt_end
)
kept_decode
=
max
(
0
,
end_pos
-
decode_start
)
# 本 step 的逻辑区间内,prompt token 里“必须保留”的部分:
# protected_prefix / protected_suffix /(可选)最后一个 prompt token。
kept_prompt_must_keep
=
count_prompt_must_keep_in_range
(
prompt_len
=
prompt_end
,
start_pos
=
start_pos
,
end_pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
)
# 本 step 通过 Top-K 策略“额外保留”的 prompt token 数。
# 预算定义在 prompt 的“非保护区”上,并由 `compute_topk_budget_step` 按 step 分摊。
kept_prompt_topk
=
compute_topk_budget_step
(
prompt_len
=
prompt_end
,
start_pos
=
start_pos
,
end_pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
# 本 step 结束后:KV cache 实际长度按“保留的 KV 条目数”推进。
request
.
num_kv_tokens
+=
(
kept_decode
+
kept_prompt_must_keep
+
kept_prompt_topk
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment