Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f49e5aff
Unverified
Commit
f49e5aff
authored
Apr 12, 2025
by
Lily Liu
Committed by
GitHub
Apr 12, 2025
Browse files
[V1][Spec Decode] KV cache slots for eagle heads (#16370)
Signed-off-by:
LiuXiaoxuanPKU
<
lilyliupku@gmail.com
>
parent
6c11ecf8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
98 additions
and
18 deletions
+98
-18
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+74
-12
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-4
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+11
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+1
-0
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
f49e5aff
...
@@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
...
@@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
from
vllm.utils
import
GiB_bytes
,
sha256
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
# yapf: disable
from
vllm.v1.core.kv_cache_utils
import
(
NONE_HASH
,
BlockHashType
,
from
vllm.v1.core.kv_cache_utils
import
(
NONE_HASH
,
BlockHashType
,
...
@@ -48,6 +49,18 @@ def make_request(request_id,
...
@@ -48,6 +49,18 @@ def make_request(request_id,
)
)
def
new_kv_cache_spec
(
block_size
=
16
,
num_kv_heads
=
2
,
head_size
=
64
,
dtype
=
torch
.
float32
,
use_mla
=
False
):
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
use_mla
=
use_mla
)
def
test_none_hash
():
def
test_none_hash
():
assert
NONE_HASH
is
not
None
assert
NONE_HASH
is
not
None
assert
isinstance
(
NONE_HASH
,
int
)
assert
isinstance
(
NONE_HASH
,
int
)
...
@@ -327,18 +340,6 @@ def test_metrics():
...
@@ -327,18 +340,6 @@ def test_metrics():
def
test_unify_kv_cache_configs
():
def
test_unify_kv_cache_configs
():
def
new_kv_cache_spec
(
block_size
=
16
,
num_kv_heads
=
2
,
head_size
=
64
,
dtype
=
torch
.
float32
,
use_mla
=
False
):
return
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
use_mla
=
use_mla
)
same_kv_cache_config
=
[
same_kv_cache_config
=
[
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
...
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
...
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
estimated_max_len
=
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
estimated_max_len
=
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
8
*
GiB_bytes
)
8
*
GiB_bytes
)
assert
estimated_max_len
==
want_estimated_max_len
assert
estimated_max_len
==
want_estimated_max_len
def
test_allocate_with_lookahead
():
"""Verify that lookahead tokens correctly affect block allocation"""
block_size
=
4
config
=
KVCacheConfig
(
num_blocks
=
10
,
tensors
=
{
"layer1"
:
KVCacheTensor
(
100
),
},
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
(
block_size
=
block_size
)),
],
)
request
=
make_request
(
request_id
=
0
,
prompt_token_ids
=
[],
mm_positions
=
None
,
mm_hashes
=
None
,
)
# Test case 1: Requires additional lookahead tokens
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
num_preallocate_tokens
=
0
)
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
num_tokens
=
3
,
num_lookahead_tokens
=
2
,
# Total required: 3+2=5 tokens
)
assert
len
(
blocks
)
==
2
# ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
num_preallocate_tokens
=
4
)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
# required_blocks = ceil((3 + 2) /4) = 2
# total_blocks = 1 + 2 = 3
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
num_tokens
=
3
,
num_lookahead_tokens
=
2
,
)
assert
len
(
blocks
)
==
3
# Test case 3: With precomputed blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
max_model_len
=
100
,
num_preallocate_tokens
=
4
)
blocks
=
kv_cache_manager
.
allocate_slots
(
request
,
num_tokens
=
3
,
num_lookahead_tokens
=
4
,
)
assert
len
(
blocks
)
==
2
vllm/v1/core/kv_cache_manager.py
View file @
f49e5aff
...
@@ -164,7 +164,8 @@ class KVCacheManager:
...
@@ -164,7 +164,8 @@ class KVCacheManager:
self
,
self
,
request
:
Request
,
request
:
Request
,
num_tokens
:
int
,
num_tokens
:
int
,
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
,
num_lookahead_tokens
:
int
=
0
,
)
->
Optional
[
list
[
KVCacheBlock
]]:
)
->
Optional
[
list
[
KVCacheBlock
]]:
"""Add slots for a request with new tokens to append.
"""Add slots for a request with new tokens to append.
...
@@ -174,6 +175,9 @@ class KVCacheManager:
...
@@ -174,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed.
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
Blocks layout:
-----------------------------------------------------------------------
-----------------------------------------------------------------------
...
@@ -211,8 +215,9 @@ class KVCacheManager:
...
@@ -211,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
num_computed_tokens
=
(
request
.
num_computed_tokens
+
len
(
new_computed_blocks
)
*
self
.
block_size
)
len
(
new_computed_blocks
)
*
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
,
num_required_blocks
=
cdiv
(
self
.
block_size
)
num_computed_tokens
+
num_tokens
+
num_lookahead_tokens
,
self
.
block_size
)
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
len
(
new_computed_blocks
))
len
(
new_computed_blocks
))
...
@@ -246,8 +251,11 @@ class KVCacheManager:
...
@@ -246,8 +251,11 @@ class KVCacheManager:
else
:
else
:
# Get new blocks from the free block pool considering
# Get new blocks from the free block pool considering
# preallocated blocks.
# preallocated blocks.
num_preallocate_blocks
=
max
(
0
,
self
.
num_preallocate_blocks
-
num_lookahead_tokens
//
self
.
block_size
)
num_new_blocks
=
min
(
num_new_blocks
=
min
(
num_new_blocks
+
self
.
num_preallocate_blocks
,
num_new_blocks
+
num_preallocate_blocks
,
self
.
block_pool
.
get_num_free_blocks
(),
self
.
block_pool
.
get_num_free_blocks
(),
# Should not exceed the maximum number of blocks per request.
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# This is especially because the block table has the shape
...
...
vllm/v1/core/sched/scheduler.py
View file @
f49e5aff
...
@@ -7,7 +7,8 @@ from collections import deque
...
@@ -7,7 +7,8 @@ from collections import deque
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
...
@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
structured_output_manager
:
StructuredOutputManager
,
structured_output_manager
:
StructuredOutputManager
,
speculative_config
:
SpeculativeConfig
=
None
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
log_stats
:
bool
=
False
,
...
@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
...
@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self
.
encoder_cache_manager
=
EncoderCacheManager
(
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
cache_size
=
encoder_cache_size
)
self
.
num_lookahead_tokens
=
0
if
speculative_config
and
speculative_config
.
method
==
"eagle"
:
self
.
num_lookahead_tokens
=
\
speculative_config
.
num_speculative_tokens
def
schedule
(
self
)
->
SchedulerOutput
:
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
...
@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while
True
:
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
)
request
,
num_new_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
if
new_blocks
is
None
:
if
new_blocks
is
None
:
# The request cannot be scheduled.
# The request cannot be scheduled.
# Preempt the lowest-priority request.
# Preempt the lowest-priority request.
...
...
vllm/v1/engine/core.py
View file @
f49e5aff
...
@@ -98,6 +98,7 @@ class EngineCore:
...
@@ -98,6 +98,7 @@ class EngineCore:
cache_config
=
vllm_config
.
cache_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
lora_config
=
vllm_config
.
lora_config
,
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
speculative_config
=
vllm_config
.
speculative_config
,
structured_output_manager
=
self
.
structured_output_manager
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
>
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment