Unverified Commit f49e5aff authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[V1][Spec Decode] KV cache slots for eagle heads (#16370)


Signed-off-by: default avatarLiuXiaoxuanPKU <lilyliupku@gmail.com>
parent 6c11ecf8
...@@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig ...@@ -7,6 +7,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256 from vllm.utils import GiB_bytes, sha256
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail # disable yapf here as it formats differently than isort such that both fail
# yapf: disable # yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
...@@ -48,6 +49,18 @@ def make_request(request_id, ...@@ -48,6 +49,18 @@ def make_request(request_id,
) )
def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
def test_none_hash(): def test_none_hash():
assert NONE_HASH is not None assert NONE_HASH is not None
assert isinstance(NONE_HASH, int) assert isinstance(NONE_HASH, int)
...@@ -327,18 +340,6 @@ def test_metrics(): ...@@ -327,18 +340,6 @@ def test_metrics():
def test_unify_kv_cache_configs(): def test_unify_kv_cache_configs():
def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
same_kv_cache_config = [ same_kv_cache_config = [
KVCacheConfig( KVCacheConfig(
num_blocks=10, num_blocks=10,
...@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len, ...@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
8 * GiB_bytes) 8 * GiB_bytes)
assert estimated_max_len == want_estimated_max_len assert estimated_max_len == want_estimated_max_len
def test_allocate_with_lookahead():
"""Verify that lookahead tokens correctly affect block allocation"""
block_size = 4
config = KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
},
kv_cache_groups=[
KVCacheGroupSpec(["layer1"],
new_kv_cache_spec(block_size=block_size)),
],
)
request = make_request(
request_id=0,
prompt_token_ids=[],
mm_positions=None,
mm_hashes=None,
)
# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=0)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
# num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
# required_blocks = ceil((3 + 2) /4) = 2
# total_blocks = 1 + 2 = 3
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 3
# Test case 3: With precomputed blocks
# num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
# required_blocks = ceil((3 + 4) / 4) = 2
# total_blocks = 0 + 2 = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config,
max_model_len=100,
num_preallocate_tokens=4)
blocks = kv_cache_manager.allocate_slots(
request,
num_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
...@@ -164,7 +164,8 @@ class KVCacheManager: ...@@ -164,7 +164,8 @@ class KVCacheManager:
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]: ) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append. """Add slots for a request with new tokens to append.
...@@ -174,6 +175,9 @@ class KVCacheManager: ...@@ -174,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed. not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the new_computed_blocks: A list of new computed blocks just hitting the
prefix caching. prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout: Blocks layout:
----------------------------------------------------------------------- -----------------------------------------------------------------------
...@@ -211,8 +215,9 @@ class KVCacheManager: ...@@ -211,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size) len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens, num_required_blocks = cdiv(
self.block_size) num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) - num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks)) len(new_computed_blocks))
...@@ -246,8 +251,11 @@ class KVCacheManager: ...@@ -246,8 +251,11 @@ class KVCacheManager:
else: else:
# Get new blocks from the free block pool considering # Get new blocks from the free block pool considering
# preallocated blocks. # preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min( num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks, num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(), self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request. # Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape # This is especially because the block table has the shape
......
...@@ -7,7 +7,8 @@ from collections import deque ...@@ -7,7 +7,8 @@ from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
...@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface): ...@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager, structured_output_manager: StructuredOutputManager,
speculative_config: SpeculativeConfig = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False, include_finished_set: bool = False,
log_stats: bool = False, log_stats: bool = False,
...@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface): ...@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager = EncoderCacheManager( self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size) cache_size=encoder_cache_size)
self.num_lookahead_tokens = 0
if speculative_config and speculative_config.method == "eagle":
self.num_lookahead_tokens = \
speculative_config.num_speculative_tokens
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler. # There's no "decoding phase" nor "prefill phase" in the scheduler.
...@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface): ...@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens) request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
......
...@@ -98,6 +98,7 @@ class EngineCore: ...@@ -98,6 +98,7 @@ class EngineCore:
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config, lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
speculative_config=vllm_config.speculative_config,
structured_output_manager=self.structured_output_manager, structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1, > 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment