Unverified commit 6550114c, authored by Chen Zhang, committed by GitHub
Browse files

[v1] Redo "Support multiple KV cache groups in GPU model runner (#17945)" (#18593)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 9520a989
......@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
hash_request_tokens,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
......@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False):
use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla)
use_mla=use_mla,
sliding_window=sliding_window)
def test_none_hash(monkeypatch):
......@@ -492,6 +495,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs(diff_kv_cache_config)
def test_merge_kv_cache_spec():
    """Check how ``merge`` combines the KV cache specs of several layers.

    Covers specs that merge cleanly, specs that are expected to be rejected
    with ``AssertionError``, and the handling of the ``sliding_window``
    field (conflicting sizes raise ``ValueError``; ``None`` entries are
    ignored when another layer provides a window size).
    """
    # Two identical specs merge into a spec carrying the same field values.
    identical = [new_kv_cache_spec(num_kv_heads=32) for _ in range(2)]
    merged = identical[0].merge(identical)
    assert merged.block_size == 16
    assert merged.num_kv_heads == 32
    assert merged.head_size == 64
    assert merged.dtype == torch.float32
    assert merged.sliding_window is None

    # Specs that differ in the number of KV heads must not be merged.
    mismatched_heads = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=16),
    ]
    with pytest.raises(AssertionError):
        mismatched_heads[0].merge(mismatched_heads)

    # A full-attention spec and a sliding-window spec must not be merged,
    # regardless of which element the merge is invoked through.
    full_spec = new_kv_cache_spec(num_kv_heads=32)
    window_spec = SlidingWindowSpec(
        block_size=full_spec.block_size,
        num_kv_heads=full_spec.num_kv_heads,
        head_size=full_spec.head_size,
        dtype=full_spec.dtype,
        use_mla=full_spec.use_mla,
        sliding_window=1,
    )
    mixed_types = [full_spec, window_spec]
    for entry in mixed_types:
        with pytest.raises(AssertionError):
            entry.merge(mixed_types)

    # Two distinct non-None window sizes in one group are an error.
    conflicting_windows = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
    ]
    with pytest.raises(ValueError):
        conflicting_windows[0].merge(conflicting_windows)

    # A shared window size is carried over to the merged spec.
    matching_windows = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
    ]
    merged = matching_windows[0].merge(matching_windows)
    assert merged.sliding_window == 1

    # Layers without a window (None) do not conflict with a windowed layer.
    window_and_none = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
    ]
    merged = window_and_none[0].merge(window_and_none)
    assert merged.sliding_window == 1
@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),
......
......@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Check full block metadata
parent_block_hash = None
......@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
......@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6]
assert blocks.get_block_ids() == [[6]]
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
......@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
......@@ -208,7 +208,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
# Check full block metadata
......@@ -233,13 +233,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
......@@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4]
assert block_ids != [[1, 2, 3, 4]]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
for block_id in block_ids:
for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1
manager.free(req2)
......@@ -307,7 +307,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
......@@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2]
assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10]
assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
......@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
......@@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
......@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [1, 2, 3, 4]
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
......@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
assert blocks.get_block_ids() == [[5]]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
......
......@@ -9,9 +9,11 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
InputBatch)
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
......@@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def get_kv_cache_config() -> KVCacheConfig:
    """Return a minimal KV cache config with one layer in one cache group.

    The single group holds ``"layer.0"`` and uses a full-attention spec
    with ``block_size=1``.
    """
    layer_name = "layer.0"
    attention_spec = FullAttentionSpec(
        block_size=1,
        num_kv_heads=1,
        head_size=16,
        dtype=torch.float16,
        use_mla=False,
    )
    only_group = KVCacheGroupSpec(
        layer_names=[layer_name],
        kv_cache_spec=attention_spec,
    )
    return KVCacheConfig(
        num_blocks=10,
        tensors={layer_name: KVCacheTensor(size=1024)},
        kv_cache_groups=[only_group],
    )
def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
......@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
......@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[],
block_ids=[[]],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
......@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
......@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
)
reqs: list[CachedRequestState] = []
......
# SPDX-License-Identifier: Apache-2.0
import weakref
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
......@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
kv_cache_config = KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
)
runner.initialize_attn_backend(kv_cache_config)
@pytest.fixture
......@@ -48,10 +70,12 @@ def model_runner():
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
device = "cuda"
......@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[0],
block_ids=[[0]],
num_computed_tokens=0,
lora_request=None,
))
......@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table
block_table = model_runner.input_batch.block_table[0]
req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
if block_table.num_blocks_per_row[req_index] != len(
req_state.block_ids[0]):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all()
req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner):
......@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
new_block_ids=[[]],
num_computed_tokens=0,
)
......
......@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False)
total_need_load += 1
......@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True)
......@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids
block_ids = cached_req.new_block_ids[0]
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
......
......@@ -69,13 +69,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_model_len = self.runner.model_config.max_model_len
assert max_model_len == 32768,\
"AITER MLA requires max_model_len=32768"
assert self.runner.block_size == 1, "AITER MLA" \
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."
def _get_paged_kv_tensors(
self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.runner.block_size
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
......
......@@ -32,9 +32,16 @@ class KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])
def get_block_ids(self) -> list[int]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]
def get_block_ids(self) -> list[list[int]]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* each inner list contains the block_ids of the blocks in that group
"""
return [[block.block_id for block in self.blocks]]
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
......@@ -300,9 +307,9 @@ class KVCacheManager:
self,
request: Request,
num_running_requests: int,
) -> int:
) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state.
in the RUNNING state for each kv cache group.
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
......@@ -332,11 +339,14 @@ class KVCacheManager:
requests in the current step.
Returns:
int: The number of common prefix blocks.
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert request.status == RequestStatus.RUNNING
return self.single_type_manager.get_num_common_prefix_blocks(
return [
self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
]
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
......@@ -354,10 +364,8 @@ class KVCacheManager:
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]:
def get_block_ids(self, request_id: str) -> list[list[int]]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return [
block.block_id
for block in self.single_type_manager.req_to_blocks[request_id]
]
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
).get_block_ids()
......@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]]
assert all(
kv_cache_spec[layer_name] == layer_spec
for layer_name in layer_names_one_group[1:]), (
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
layer_specs = [
kv_cache_spec[layer_name] for layer_name in layer_names_one_group
]
merged_layer_spec = layer_specs[0].merge(layer_specs)
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec))
KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
return kv_cache_groups
......@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
......
......@@ -26,7 +26,7 @@ class NewRequestData:
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[int]
block_ids: list[list[int]]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
......@@ -34,7 +34,7 @@ class NewRequestData:
def from_request(
cls,
request: Request,
block_ids: list[int],
block_ids: list[list[int]],
) -> NewRequestData:
return cls(
req_id=request.request_id,
......@@ -85,7 +85,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: list[int]
new_block_ids: list[list[int]]
num_computed_tokens: int
@classmethod
......@@ -94,7 +94,7 @@ class CachedRequestData:
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[int],
new_block_ids: list[list[int]],
) -> CachedRequestData:
return cls(
req_id=request.request_id,
......@@ -131,9 +131,9 @@ class SchedulerOutput:
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests.
# Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention.
num_common_prefix_blocks: int
num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
......
......@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {}
req_to_new_block_ids: dict[str, list[list[int]]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
......@@ -486,7 +486,8 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = 0
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
......@@ -573,7 +574,7 @@ class Scheduler(SchedulerInterface):
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
new_block_ids: list[list[int]],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
......@@ -949,7 +950,9 @@ class Scheduler(SchedulerInterface):
"""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
......@@ -966,9 +969,10 @@ class Scheduler(SchedulerInterface):
"""
if request.request_id not in self.finished_recving_kv_req_ids:
return False
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
# Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
......
# SPDX-License-Identifier: Apache-2.0
import copy
from dataclasses import dataclass
from typing import Optional
import torch
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
......@@ -53,6 +56,16 @@ class KVCacheSpec:
"""
raise NotImplementedError
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
"All layers in the same KV cache group must share the same "
"type_id.")
return copy.deepcopy(specs[0])
@dataclass
class AttentionSpec(KVCacheSpec):
......@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
@dataclass
class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
@property
def type_id(self) -> str:
......@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of FullAttentionSpec objects into a single
    FullAttentionSpec object.

    The base-class merge runs first. The merged ``sliding_window`` is then
    the single non-None window size found among ``specs``, or None when no
    spec sets one.

    Raises:
        ValueError: if the specs carry more than one distinct non-None
            sliding window size.
    """
    merged_spec = super().merge(specs)
    # Collect the distinct window sizes; layers without a window are ignored.
    window_sizes = {
        spec.sliding_window
        for spec in specs if spec.sliding_window is not None
    }
    if len(window_sizes) > 1:
        raise ValueError(
            "All sliding window layers in the same KV cache group "
            "must have the same window size.")
    merged_spec.sliding_window = window_sizes.pop() if window_sizes else None
    return merged_spec
@dataclass
class SlidingWindowSpec(AttentionSpec):
......
......@@ -4,6 +4,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__)
......@@ -96,3 +97,43 @@ class BlockTable:
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, block_size: int) -> None:
        """Create the per-group block tables.

        Exactly one BlockTable is built here. Its per-request block
        capacity is ``cdiv(max_model_len, block_size)``.
        """
        max_blocks_per_req = cdiv(max_model_len, block_size)
        only_table = BlockTable(max_num_reqs, max_blocks_per_req,
                                max_num_batched_tokens, pin_memory, device)
        self.block_tables = [only_table]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Append ``block_ids[g]`` to row ``row_idx`` of group ``g``'s table."""
        for group_idx, table in enumerate(self.block_tables):
            table.append_row(block_ids[group_idx], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``add_row`` on each group's table with that group's IDs."""
        for group_idx, table in enumerate(self.block_tables):
            table.add_row(block_ids[group_idx], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Forward ``move_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Forward ``swap_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Forward ``commit(num_reqs)`` to every group's table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Forward ``clear()`` to every group's table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
......@@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5
......@@ -29,7 +29,7 @@ class CachedRequestState:
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: list[int]
block_ids: list[list[int]]
num_computed_tokens: int
output_token_ids: list[int]
......@@ -58,15 +58,14 @@ class InputBatch:
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_size: int,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
......@@ -99,12 +98,13 @@ class InputBatch:
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = BlockTable(
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_size=block_size,
)
# Sampling-related.
......
This diff is collapsed.
......@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.vocab_size,
)
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
......@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
dtype=torch.int32,
device="cpu")
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
......@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table.
out=self.input_batch.block_table[0].
slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
......@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.input_batch.block_table.slot_mapping_cpu[
self.input_batch.block_table[0].slot_mapping_cpu[
total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = (
self.input_batch.block_table.
self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device))
block_tables = self.block_table_cpu[:self.max_num_reqs]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device)
......@@ -1263,6 +1254,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
)
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment