Unverified commit cd4a95e3, authored by Yifan Qiao, committed by GitHub
Browse files

[Feat][Core] Support multiple KV cache groups in Hybrid KV Coordinator (#31707)


Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
parent d5ec6c05
...@@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheConfig,
KVCacheGroupSpec, KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec, SlidingWindowSpec,
) )
...@@ -106,8 +107,23 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ...@@ -106,8 +107,23 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def make_kv_cache_config_hybrid_model( def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig: ) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
)
elif second_spec_type == "mamba":
second_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
return KVCacheConfig( return KVCacheConfig(
num_blocks=num_blocks, num_blocks=num_blocks,
kv_cache_tensors=[], kv_cache_tensors=[],
...@@ -123,16 +139,49 @@ def make_kv_cache_config_hybrid_model( ...@@ -123,16 +139,49 @@ def make_kv_cache_config_hybrid_model(
), ),
KVCacheGroupSpec( KVCacheGroupSpec(
["layer2"], ["layer2"],
SlidingWindowSpec( second_spec,
),
KVCacheGroupSpec(
["layer3"],
second_spec,
),
],
)
def make_kv_cache_config_three_types(
block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
if third_spec_type == "mamba":
third_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
elif third_spec_type == "sliding_window":
third_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=1, num_kv_heads=1,
head_size=1, head_size=1,
dtype=torch.float32, dtype=torch.float32,
sliding_window=2 * block_size,
), ),
), ),
KVCacheGroupSpec( KVCacheGroupSpec(
["layer3"], ["layer2"],
SlidingWindowSpec( SlidingWindowSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=1, num_kv_heads=1,
...@@ -141,6 +190,10 @@ def make_kv_cache_config_hybrid_model( ...@@ -141,6 +190,10 @@ def make_kv_cache_config_hybrid_model(
sliding_window=2 * block_size, sliding_window=2 * block_size,
), ),
), ),
KVCacheGroupSpec(
["layer3"],
third_spec,
),
], ],
) )
...@@ -424,6 +477,184 @@ def test_prefill_hybrid_model(): ...@@ -424,6 +477,184 @@ def test_prefill_hybrid_model():
) )
def _make_hybrid_kv_cache_config(
block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
"""
Create a KVCacheConfig with the specified spec types.
Args:
block_size: The block size for KV cache.
num_blocks: The number of blocks in the KV cache.
spec_types: List of spec type strings. Supported types:
- "full": FullAttentionSpec
- "sliding_window": SlidingWindowSpec with window=2*block_size
- "sliding_window_large": SlidingWindowSpec with window=4*block_size
- "mamba": MambaSpec
"""
spec_map = {
"full": lambda: FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
"sliding_window": lambda: SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
),
"sliding_window_large": lambda: SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
),
"mamba": lambda: MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
),
}
kv_cache_groups = [
KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]())
for i, spec_type in enumerate(spec_types)
]
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=kv_cache_groups,
)
# Test cases covering various combinations of KV cache spec types:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
# Each entry is the list of spec type names handed to
# `_make_hybrid_kv_cache_config`; list position i becomes KV cache group i.
_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], id="2g-full+sw"),
    pytest.param(["full", "mamba"], id="2g-full+mamba"),
    # 2 groups: 0 full (all other types)
    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
    # 3 groups: 1 full + 2 others (same type)
    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
    # 3 groups: 1 full + 2 others (different types)
    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
    pytest.param(
        ["full", "sliding_window", "sliding_window_large"],
        id="3g-full+sw+sw_large",
    ),
    # 3 groups: 2 full + 1 other
    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
    # 4 groups: interleaved (full, other, full, other) where both "other"
    # groups share one spec. The previous entry here repeated the spec list
    # of the mixed-window case below under a different id.
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window"],
        id="4g-interleaved-full+sw",
    ),
    pytest.param(
        ["full", "mamba", "full", "mamba"],
        id="4g-interleaved-full+mamba",
    ),
    # 4 groups: interleaved with different sliding windows
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window_large"],
        id="4g-interleaved-full+sw_mixed",
    ),
    # 4 groups: 0 full (all other types)
    pytest.param(
        ["sliding_window", "mamba", "sliding_window_large", "mamba"],
        id="4g-sw+mamba+sw_large+mamba",
    ),
    # 4 groups: 2 full + 2 others (grouped)
    pytest.param(
        ["full", "full", "sliding_window", "mamba"],
        id="4g-2full+sw+mamba",
    ),
]
@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
    """
    Exercise prefix caching on a hybrid model built from `spec_types`.

    A first request populates the cache with three full blocks plus a
    partial tail. A second request sharing those three blocks must then
    report them as computed, and both requests must receive block IDs for
    every KV cache group.

    The parametrization covers:
    - Various combinations (full attn + other attn types)
    - Varying number of groups (2, 3, or 4)
    - 0, 1, or 2 full attention groups in the combination
    - Two sliding_window attn groups with different window sizes
    - Interleaved group IDs (full attn and other types alternating)
    - Mamba spec with other attention types
    """
    block_size = 16
    num_groups = len(spec_types)

    manager = KVCacheManager(
        # Ten blocks per group, so the pool is large enough for all groups.
        _make_hybrid_kv_cache_config(block_size, 10 * num_groups, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Three complete blocks (48 tokens): block k is filled with token id k.
    shared_prefix: list[int] = []
    for token_id in range(3):
        shared_prefix.extend([token_id] * block_size)
    first_prompt = shared_prefix + [3] * 7
    second_prompt = shared_prefix + [4] * 5

    # --- Request 0: nothing is cached yet. ---
    first_req = make_request("0", first_prompt, block_size, sha256)
    cached, hit_tokens = manager.get_computed_blocks(first_req)
    assert len(first_req.block_hashes) == 3
    assert not cached.blocks[0]  # No cache hit initially
    assert hit_tokens == 0

    allocated = manager.allocate_slots(
        first_req,
        len(first_prompt),
        len(cached.blocks[0]) * block_size,
        cached,
    )
    assert allocated is not None
    # One list of block IDs per KV cache group.
    assert len(allocated.get_block_ids()) == num_groups

    # --- Request 1: the shared prefix must be served from the cache. ---
    second_req = make_request("1", second_prompt, block_size, sha256)
    cached, hit_tokens = manager.get_computed_blocks(second_req)
    assert hit_tokens == 3 * block_size
    assert len(cached.blocks) == num_groups

    # Only the tokens that were not cache hits still need slots.
    allocated = manager.allocate_slots(
        second_req,
        len(second_prompt) - hit_tokens,
        hit_tokens,
        cached,
    )
    assert allocated is not None
    assert len(allocated.get_block_ids()) == num_groups

    manager.free(first_req)
    manager.free(second_req)
def test_prefill_plp(): def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests. """Test prefill with APC and some prompt logprobs (plp) requests.
......
...@@ -14,7 +14,7 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -14,7 +14,7 @@ from vllm.v1.core.kv_cache_utils import (
) )
from vllm.v1.core.single_type_kv_cache_manager import ( from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager, CrossAttentionManager,
FullAttentionManager, SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec, get_manager_for_kv_cache_spec,
) )
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
...@@ -354,9 +354,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -354,9 +354,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
""" """
KV cache coordinator for hybrid models with multiple KV cache types, and KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups. thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, and one of them must be full attention.
May extend to more general cases in the future.
""" """
def __init__( def __init__(
...@@ -397,70 +394,46 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -397,70 +394,46 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def verify_and_split_kv_cache_groups(self) -> None:
    """
    Groups KV cache groups by their spec type for efficient batch processing
    during cache hit lookup.

    KV cache groups whose specs compare equal are merged into one entry, so
    a single manager-class lookup can serve all of them at once.

    Sets:
        self.attention_groups: List of ``(spec, group_ids, manager_cls)``
            tuples, one per distinct spec, with FullAttentionSpec entries
            ordered first.
        self.lcm_block_size: The LCM of the block sizes of all distinct
            specs.

    Raises:
        AssertionError: If two groups with equal specs use different
            manager classes, or if fewer than two distinct specs exist.
    """
    attention_groups: list[
        tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]]
    ] = []

    for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
        manager_cls = self.single_type_managers[i].__class__
        spec = g.kv_cache_spec

        # Try to find an existing group with the same spec
        for existing_spec, group_ids, existing_cls in attention_groups:
            if existing_spec == spec:
                assert manager_cls is existing_cls, (
                    "Expected same manager class for identical KV cache specs."
                )
                group_ids.append(i)
                break
        else:
            # No entry with an equal spec yet: start a new one.
            attention_groups.append((spec, [i], manager_cls))

    assert len(attention_groups) > 1, (
        "HybridKVCacheCoordinator requires at least two attention groups."
    )

    # Put full attention first: its efficient left-to-right scan provides
    # a tighter initial bound, reducing work for subsequent groups.
    # `sorted` is stable, so the remaining entries keep their relative order.
    self.attention_groups = sorted(
        attention_groups,
        key=lambda x: not isinstance(x[0], FullAttentionSpec),
    )

    # The LCM of the block sizes of all attention types.
    # The cache hit length must be a multiple of the LCM of the block sizes
    # to make sure the cache hit length is a multiple of the block size of
    # each attention type. Requiring this because we don't support partial
    # block cache hit yet.
    block_sizes = [spec.block_size for spec, _, _ in attention_groups]
    self.lcm_block_size = lcm(*block_sizes)
def find_longest_cache_hit( def find_longest_cache_hit(
self, self,
...@@ -468,7 +441,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -468,7 +441,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
max_cache_hit_length: int, max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]: ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
""" """
Find the longest cache hit for the request. Find the longest cache hit using an iterative fixed-point algorithm.
Each attention type either accepts the current candidate length or
reduces it. If any type reduces the length, restart checks over all
types. This converges because length monotonically decreases and is
bounded below by 0.
Args: Args:
block_hashes: The block hashes of the request. block_hashes: The block hashes of the request.
...@@ -476,75 +454,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -476,75 +454,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
Returns: Returns:
A tuple containing: A tuple containing:
- A list of the cache hit blocks for each single type manager. - A tuple of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit. - The number of tokens of the longest cache hit.
""" """
# First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size: def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
# Common case. if kv_cache_spec.block_size == self.hash_block_size:
full_attention_block_hashes: BlockHashList = block_hashes return block_hashes
else: return BlockHashListWithBlockSize(
# block_size is a multiple of hash_block_size. This happens when different block_hashes, self.hash_block_size, kv_cache_spec.block_size
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
) )
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=other_block_hashes, num_groups = len(self.kv_cache_config.kv_cache_groups)
max_length=hit_length, hit_length = max_cache_hit_length
kv_cache_group_ids=self.other_group_ids, hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
block_pool=self.block_pool,
kv_cache_spec=self.other_spec, while True:
use_eagle=self.use_eagle, curr_hit_length = hit_length
alignment_tokens=self.lcm_block_size,
) for spec, group_ids, manager_cls in self.attention_groups:
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size is_full_attn = isinstance(spec, FullAttentionSpec)
# NOTE: the prefix cache hit length must be a multiple of block_size as # Full attention: reuse cached blocks (downward-closed property)
# we don't support partial block cache hit yet. The cache hit length cached_blocks = hit_blocks_by_group[group_ids[0]]
# of other attention is ensured to be a multiple of the block size of if is_full_attn and cached_blocks is not None:
# full attention layers in current implementation, because hit_length is # For full attention, we only need to compute the cache hit
# a multiple of other attention's block size, and other attention's # length once. Starting from the second iteration, if the
# block size is a multiple of full attention's block size (verified in # curr_hit_length is reduced by other groups, we can simply
# `verify_and_split_kv_cache_groups`). # keep the first (curr_hit_length // block_size) blocks from
assert hit_length % self.full_attention_block_size == 0 # the last iteration.
num_blocks = curr_hit_length // spec.block_size
# Truncate the full attention cache hit to the length of the curr_hit_length = num_blocks * spec.block_size
# cache hit of the other attention. for group_id in group_ids:
for group_hit_blocks in hit_blocks_full_attn: blocks = hit_blocks_by_group[group_id]
del group_hit_blocks[hit_length // self.full_attention_block_size :] assert blocks is not None
del blocks[num_blocks:]
# Merge the hit blocks of full attention and other attention. else:
if self.full_attn_first: hit_blocks = manager_cls.find_longest_cache_hit(
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn block_hashes=_get_block_hashes(spec),
else: max_length=curr_hit_length,
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn kv_cache_group_ids=group_ids,
return hit_blocks, hit_length block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
curr_hit_length = len(hit_blocks[0]) * spec.block_size
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks
if curr_hit_length < hit_length:
hit_length = curr_hit_length
else:
break
return tuple(
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
), hit_length
def get_kv_cache_coordinator( def get_kv_cache_coordinator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment