Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
20e489ea
Unverified
Commit
20e489ea
authored
Apr 27, 2025
by
Lily Liu
Committed by
GitHub
Apr 27, 2025
Browse files
[V1][Spec Decode] Make eagle compatible with prefix caching. (#17137)
Signed-off-by:
LiuXiaoxuanPKU
<
lilyliupku@gmail.com
>
parent
4213475e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
9 deletions
+81
-9
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+57
-0
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+0
-1
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+10
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+14
-8
No files found.
tests/v1/core/test_prefix_caching.py
View file @
20e489ea
...
...
@@ -719,3 +719,60 @@ def test_prefix_cache_stats_disabled():
# Ensure prefix_cache_stats remains None
assert
manager
.
prefix_cache_stats
is
None
def test_eagle_enabled_removes_last_block():
    """Verify Eagle drops the last matched block even when the request
    length is divisible by the block size.

    A 48-token request (3 full blocks of 16 tokens) is allocated and then
    freed to prime the prefix cache. An identical request issued afterwards
    must report a single reusable block, i.e. 16 computed tokens.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Request with 3 full blocks (48 tokens)
    token_ids = [0] * (3 * block_size)
    req = make_request("divisible_request", token_ids)

    # Prime the cache
    computed_blocks, _ = manager.get_computed_blocks(req)
    manager.allocate_slots(req, len(token_ids), computed_blocks)
    manager.free(req)

    # New request with same tokens + Eagle enabled
    req_eagle = make_request("eagle_divisible", token_ids)
    computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)

    # Should retain 1 block:
    # 1. Original 3 blocks → pop last hash → 2 matched blocks
    # 2. Eagle is enabled and there is a cache hit → the last matched
    #    block is popped as well → 1 block remains
    assert len(computed_blocks) == 1
    assert num_tokens == 1 * block_size  # 16 tokens
def test_eagle_with_partial_blocks():
    """Check the Eagle prefix-cache result for a request that ends in a
    partially filled block."""
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, num_blocks=10),
        max_model_len=8192,
        enable_caching=True,
        use_eagle=True,
    )

    # Non-divisible length: two full blocks followed by 5 extra tokens.
    prompt_len = 2 * block_size + 5
    token_ids = [0] * prompt_len

    # Warm the cache with a first request, then release its blocks.
    warmup_req = make_request("partial_block_test", token_ids)
    warmup_blocks, _ = manager.get_computed_blocks(warmup_req)
    manager.allocate_slots(warmup_req, len(token_ids), warmup_blocks)
    manager.free(warmup_req)

    # Same tokens again, looked up through the Eagle-enabled manager.
    eagle_req = make_request("partial_eagle", token_ids)
    hit_blocks, hit_tokens = manager.get_computed_blocks(eagle_req)

    # 2 full blocks match originally; Eagle removes 1, leaving 1.
    expected_blocks = 1
    assert len(hit_blocks) == expected_blocks
    assert hit_tokens == expected_blocks * block_size
tests/v1/e2e/test_spec_decode.py
View file @
20e489ea
...
...
@@ -44,7 +44,6 @@ def test_prompts():
@pytest.fixture
def sampling_config():
    """Return greedy sampling parameters (temperature 0, at most 10 tokens)."""
    # Only greedy decoding is supported for now.
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=10,
        ignore_eos=False,
    )
    return greedy_params
...
...
vllm/v1/core/kv_cache_manager.py
View file @
20e489ea
...
...
@@ -25,6 +25,7 @@ class KVCacheManager:
max_model_len
:
int
,
enable_caching
:
bool
=
True
,
caching_hash_algo
:
str
=
"builtin"
,
use_eagle
:
bool
=
False
,
log_stats
:
bool
=
False
,
)
->
None
:
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
...
...
@@ -38,6 +39,7 @@ class KVCacheManager:
self
.
enable_caching
=
enable_caching
self
.
caching_hash_fn
=
sha256
if
caching_hash_algo
==
"sha256"
else
hash
self
.
use_eagle
=
use_eagle
self
.
log_stats
=
log_stats
# FIXME: make prefix cache stats conditional on log_stats
self
.
prefix_cache_stats
=
PrefixCacheStats
()
if
log_stats
else
None
...
...
@@ -134,6 +136,14 @@ class KVCacheManager:
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
if
self
.
use_eagle
and
len
(
computed_blocks
)
>
0
:
# Drop the last matched block if (1) eagle is enabled and
# (2) there is a cache hit.
# This is to recompute the last block to get the required
# hidden states for eagle drafting head.
computed_blocks
.
pop
()
if
self
.
log_stats
:
assert
self
.
prefix_cache_stats
is
not
None
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
20e489ea
...
...
@@ -74,13 +74,6 @@ class Scheduler(SchedulerInterface):
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
assert
num_gpu_blocks
is
not
None
and
num_gpu_blocks
>
0
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
,
caching_hash_algo
=
self
.
cache_config
.
prefix_caching_hash_algo
,
log_stats
=
self
.
log_stats
)
self
.
block_size
=
self
.
cache_config
.
block_size
# req_id -> Request
...
...
@@ -123,12 +116,24 @@ class Scheduler(SchedulerInterface):
cache_size
=
encoder_cache_size
)
speculative_config
=
vllm_config
.
speculative_config
self
.
use_eagle
=
False
self
.
num_spec_tokens
=
self
.
num_lookahead_tokens
=
0
if
speculative_config
:
self
.
num_spec_tokens
=
speculative_config
.
num_speculative_tokens
if
speculative_config
.
use_eagle
():
self
.
use_eagle
=
True
self
.
num_lookahead_tokens
=
self
.
num_spec_tokens
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
,
caching_hash_algo
=
self
.
cache_config
.
prefix_caching_hash_algo
,
use_eagle
=
self
.
use_eagle
,
log_stats
=
self
.
log_stats
)
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
...
@@ -317,7 +322,8 @@ class Scheduler(SchedulerInterface):
# Get already-cached tokens.
computed_blocks
,
num_computed_tokens
=
\
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
# Get externally-cached tokens if using a KVConnector.
num_external_tokens
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment