Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c12a765
Unverified
Commit
7c12a765
authored
Jul 09, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 09, 2025
Browse files
[Misc] Simplify the prefix caching logic on draft tokens (#20701)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
cd587c93
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
11 deletions
+10
-11
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+10
-6
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+0
-5
No files found.
vllm/v1/core/kv_cache_manager.py
View file @
7c12a765
...
@@ -190,7 +190,6 @@ class KVCacheManager:
...
@@ -190,7 +190,6 @@ class KVCacheManager:
num_new_tokens
:
int
,
num_new_tokens
:
int
,
num_new_computed_tokens
:
int
=
0
,
num_new_computed_tokens
:
int
=
0
,
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
num_draft_tokens
:
int
=
0
,
num_lookahead_tokens
:
int
=
0
,
num_lookahead_tokens
:
int
=
0
,
delay_cache_blocks
:
bool
=
False
,
delay_cache_blocks
:
bool
=
False
,
)
->
Optional
[
KVCacheBlocks
]:
)
->
Optional
[
KVCacheBlocks
]:
...
@@ -286,12 +285,17 @@ class KVCacheManager:
...
@@ -286,12 +285,17 @@ class KVCacheManager:
if
not
self
.
enable_caching
or
delay_cache_blocks
:
if
not
self
.
enable_caching
or
delay_cache_blocks
:
return
KVCacheBlocks
(
new_blocks
)
return
KVCacheBlocks
(
new_blocks
)
# Speculated tokens might be rejected in the future, so we does
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
# not cache any speculated tokens. We only cache blocks with
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
# generated (accepted) tokens.
# draft tokens that could be rejected). Therefore, we cap the number
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache
=
min
(
num_computed_tokens
+
num_new_tokens
,
request
.
num_tokens
)
self
.
coordinator
.
cache_blocks
(
self
.
coordinator
.
cache_blocks
(
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
request
,
num_computed_tokens
+
num_new_tokens
-
num_draft_tokens
)
self
.
req_to_block_hashes
[
request
.
request_id
],
num_tokens_to_cache
,
)
return
KVCacheBlocks
(
new_blocks
)
return
KVCacheBlocks
(
new_blocks
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
7c12a765
...
@@ -241,15 +241,10 @@ class Scheduler(SchedulerInterface):
...
@@ -241,15 +241,10 @@ class Scheduler(SchedulerInterface):
req_index
+=
1
req_index
+=
1
continue
continue
num_draft_tokens
=
max
(
num_new_tokens
+
request
.
num_computed_tokens
-
request
.
num_tokens
,
0
)
while
True
:
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_new_tokens
,
num_new_tokens
,
num_draft_tokens
=
num_draft_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
if
new_blocks
is
None
:
if
new_blocks
is
None
:
# The request cannot be scheduled.
# The request cannot be scheduled.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment