Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2fde0fa2
Commit
2fde0fa2
authored
Jan 21, 2026
by
laibao
Browse files
feat: kvpress新增调度层 KV 压缩逻辑
parent
eef99f73
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
183 additions
and
10 deletions
+183
-10
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+183
-10
No files found.
vllm/v1/core/sched/scheduler.py
View file @
2fde0fa2
...
...
@@ -28,12 +28,16 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
SlidingWindowSpec
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.kv_compression.budget
import
(
compute_prompt_keep_len
,
compute_topk_budget_step
,
count_prompt_must_keep_in_range
)
from
vllm.platforms
import
current_platform
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
...
...
@@ -156,6 +160,53 @@ class Scheduler(SchedulerInterface):
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
full_cuda_graph
=
self
.
compilation_config
.
full_cuda_graph
self
.
use_mla
=
vllm_config
.
model_config
.
use_mla
# KV compression is only supported on CUDA/ROCm in this fork.
# Other backends (TPU/CPU/XPU/HPU/Neuron/...) do not plumb the
# num_kv_tokens-based slot mapping/metadata and can produce incorrect
# cache mappings if enabled.
self
.
kv_compression_enabled
=
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
current_platform
.
is_cuda_alike
())
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
not
self
.
kv_compression_enabled
:
logger
.
warning_once
(
"KV compression is only supported on CUDA/ROCm; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1 on this platform."
)
if
self
.
kv_compression_enabled
:
if
envs
.
VLLM_KV_COMPRESSION_POLICY
!=
"topk"
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'."
)
if
any
(
isinstance
(
group
.
kv_cache_spec
,
SlidingWindowSpec
)
for
group
in
kv_cache_config
.
kv_cache_groups
):
raise
ValueError
(
"KV compression is incompatible with sliding window "
"attention."
)
if
self
.
cache_config
.
enable_prefix_caching
:
raise
ValueError
(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression."
)
if
self
.
full_cuda_graph
:
raise
ValueError
(
"KV compression is currently incompatible with full CUDA "
"graph mode."
)
if
self
.
speculative_config
is
not
None
:
raise
ValueError
(
"KV compression is currently incompatible with "
"speculative decoding."
)
if
envs
.
VLLM_KV_COMPRESSION_PROMPT_BUDGET
<
-
1
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1."
)
if
not
(
0.0
<=
envs
.
VLLM_KV_COMPRESSION_PROMPT_RATIO
<=
1.0
):
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1]."
)
if
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
<
0
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0."
)
if
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
<
0
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0."
)
if
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
<
1
:
raise
ValueError
(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1."
)
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
...
...
@@ -207,6 +258,8 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
self
.
max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens
:
dict
[
str
,
list
[
int
]]
=
{}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids
:
set
[
str
]
=
set
()
# For logging.
scheduled_timestamp
=
time
.
monotonic
()
...
...
@@ -274,6 +327,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens
+
request
.
num_computed_tokens
-
request
.
num_tokens
,
0
)
if
(
self
.
kv_compression_enabled
and
envs
.
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and
request
.
num_computed_tokens
==
request
.
num_prompt_tokens
and
self
.
kv_cache_manager
.
truncate_to_num_tokens
(
request
.
request_id
,
request
.
num_kv_tokens
)):
force_replace_block_ids
.
add
(
request
.
request_id
)
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
...
...
@@ -295,6 +355,7 @@ class Scheduler(SchedulerInterface):
self
.
kv_cache_manager
.
free
(
preempted_req
)
preempted_req
.
status
=
RequestStatus
.
PREEMPTED
preempted_req
.
num_computed_tokens
=
0
preempted_req
.
num_kv_tokens
=
0
if
self
.
log_stats
:
preempted_req
.
record_event
(
EngineCoreEventType
.
PREEMPTED
,
scheduled_timestamp
)
...
...
@@ -321,8 +382,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids
[
request
.
request_id
]
=
req_index
req_to_new_block_ids
[
request
.
request_id
]
=
(
new_blocks
.
get_block_ids
())
if
request
.
request_id
in
force_replace_block_ids
:
req_to_new_block_ids
[
request
.
request_id
]
=
(
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
))
else
:
req_to_new_block_ids
[
request
.
request_id
]
=
(
new_blocks
.
get_block_ids
())
num_scheduled_tokens
[
request
.
request_id
]
=
num_new_tokens
token_budget
-=
num_new_tokens
req_index
+=
1
...
...
@@ -532,6 +597,8 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
if
not
self
.
kv_compression_enabled
:
request
.
num_kv_tokens
=
num_computed_tokens
# Count the number of prefix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
...
...
@@ -586,6 +653,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
,
scheduled_spec_decode_tokens
,
req_to_new_block_ids
,
force_replace_block_ids
=
force_replace_block_ids
,
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs_data
,
...
...
@@ -645,6 +713,16 @@ class Scheduler(SchedulerInterface):
encoder_budget
=
self
.
max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens
:
dict
[
str
,
list
[
int
]]
=
{}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids
:
set
[
str
]
=
set
()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras
:
set
[
int
]
=
set
()
if
self
.
lora_config
:
scheduled_loras
=
set
(
req
.
lora_request
.
lora_int_id
for
req
in
self
.
running
if
req
.
lora_request
and
req
.
lora_request
.
lora_int_id
>
0
)
assert
len
(
scheduled_loras
)
<=
self
.
lora_config
.
max_loras
# For logging.
scheduled_timestamp
=
time
.
monotonic
()
...
...
@@ -826,6 +904,8 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
if
not
self
.
kv_compression_enabled
:
request
.
num_kv_tokens
=
num_computed_tokens
# Count the number of prefix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
...
...
@@ -894,6 +974,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens
+
request
.
num_computed_tokens
-
request
.
num_tokens
,
0
)
if
(
self
.
kv_compression_enabled
and
envs
.
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and
request
.
num_computed_tokens
==
request
.
num_prompt_tokens
and
self
.
kv_cache_manager
.
truncate_to_num_tokens
(
request
.
request_id
,
request
.
num_kv_tokens
)):
force_replace_block_ids
.
add
(
request
.
request_id
)
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
...
...
@@ -915,6 +1003,7 @@ class Scheduler(SchedulerInterface):
self
.
kv_cache_manager
.
free
(
preempted_req
)
preempted_req
.
status
=
RequestStatus
.
PREEMPTED
preempted_req
.
num_computed_tokens
=
0
preempted_req
.
num_kv_tokens
=
0
if
self
.
log_stats
:
preempted_req
.
record_event
(
EngineCoreEventType
.
PREEMPTED
,
scheduled_timestamp
)
...
...
@@ -941,8 +1030,12 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids
[
request
.
request_id
]
=
req_index
req_to_new_block_ids
[
request
.
request_id
]
=
(
new_blocks
.
get_block_ids
())
if
request
.
request_id
in
force_replace_block_ids
:
req_to_new_block_ids
[
request
.
request_id
]
=
(
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
))
else
:
req_to_new_block_ids
[
request
.
request_id
]
=
(
new_blocks
.
get_block_ids
())
num_scheduled_tokens
[
request
.
request_id
]
=
num_new_tokens
token_budget
-=
num_new_tokens
req_index
+=
1
...
...
@@ -1014,6 +1107,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
,
scheduled_spec_decode_tokens
,
req_to_new_block_ids
,
force_replace_block_ids
=
force_replace_block_ids
,
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs_data
,
...
...
@@ -1076,8 +1170,81 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
for
req_id
,
num_scheduled_token
in
num_scheduled_tokens
.
items
():
request
=
self
.
requests
[
req_id
]
start_pos
=
request
.
num_computed_tokens
request
.
num_computed_tokens
+=
num_scheduled_token
if
not
self
.
kv_compression_enabled
:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request
.
num_kv_tokens
+=
num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last
=
envs
.
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos
=
request
.
num_computed_tokens
prompt_end
=
request
.
num_prompt_tokens
# Chunked prefill: do not change the prompt KV length mid-prefill.
# Otherwise, the next prefill chunk would attend to a truncated
# history (semantic change / quality collapse). Instead, keep the
# full prompt KV until the prompt is fully ingested, then apply a
# one-shot prompt compaction before decode.
if
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
start_pos
>=
prompt_end
:
# Decode token(s): keep all.
request
.
num_kv_tokens
+=
num_scheduled_token
continue
if
end_pos
<
prompt_end
:
# Prompt is still being ingested: keep all tokens for now.
request
.
num_kv_tokens
+=
num_scheduled_token
continue
# This step finishes the prompt (and may include decode tokens
# in rare cases). Apply the final prompt compression length.
kept_prompt_total
=
compute_prompt_keep_len
(
prompt_len
=
prompt_end
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
kept_decode
=
max
(
0
,
end_pos
-
max
(
start_pos
,
prompt_end
))
request
.
num_kv_tokens
=
kept_prompt_total
+
kept_decode
continue
# Decode token(s): keep all.
decode_start
=
max
(
start_pos
,
prompt_end
)
kept_decode
=
max
(
0
,
end_pos
-
decode_start
)
kept_prompt_must_keep
=
count_prompt_must_keep_in_range
(
prompt_len
=
prompt_end
,
start_pos
=
start_pos
,
end_pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
)
kept_prompt_topk
=
compute_topk_budget_step
(
prompt_len
=
prompt_end
,
start_pos
=
start_pos
,
end_pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
request
.
num_kv_tokens
+=
(
kept_decode
+
kept_prompt_must_keep
+
kept_prompt_topk
)
# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
...
...
@@ -1091,11 +1258,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
:
dict
[
str
,
int
],
spec_decode_tokens
:
dict
[
str
,
list
[
int
]],
req_to_new_block_ids
:
dict
[
str
,
tuple
[
list
[
int
],
...]],
*
,
force_replace_block_ids
:
Optional
[
set
[
str
]]
=
None
,
)
->
CachedRequestData
:
req_ids
:
list
[
str
]
=
[]
new_token_ids
:
list
[
list
[
int
]]
=
[]
new_block_ids
:
list
[
tuple
[
list
[
int
],
...]]
=
[]
num_computed_tokens
:
list
[
int
]
=
[]
num_kv_tokens
:
list
[
int
]
=
[]
resumed_from_preemption
:
list
[
bool
]
=
[]
force_replace_block_ids
=
force_replace_block_ids
or
set
()
for
req
in
itertools
.
chain
(
running_reqs
,
resumed_reqs
):
req_id
=
req
.
request_id
...
...
@@ -1111,10 +1283,9 @@ class Scheduler(SchedulerInterface):
new_token_ids
.
append
(
token_ids
)
new_block_ids
.
append
(
req_to_new_block_ids
[
req_id
])
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption
=
[
False
]
*
len
(
running_reqs
)
resumed_from_preemption
+=
[
True
]
*
len
(
resumed_reqs
)
num_kv_tokens
.
append
(
req
.
num_kv_tokens
)
resumed_from_preemption
.
append
(
(
req
in
resumed_reqs
)
or
(
req_id
in
force_replace_block_ids
))
return
CachedRequestData
(
req_ids
=
req_ids
,
...
...
@@ -1122,6 +1293,7 @@ class Scheduler(SchedulerInterface):
new_token_ids
=
new_token_ids
,
new_block_ids
=
new_block_ids
,
num_computed_tokens
=
num_computed_tokens
,
num_kv_tokens
=
num_kv_tokens
,
)
def
_try_schedule_encoder_inputs
(
...
...
@@ -1567,6 +1739,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_kv_tokens
=
num_computed_tokens
# Return that we are ready.
self
.
finished_recving_kv_req_ids
.
remove
(
request
.
request_id
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment