Unverified commit e60f550b, authored by Chen Zhang, committed by GitHub
Browse files

[v1] Support multiple KV cache groups in GPU model runner (#17945)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent f25e0d11
...@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, ...@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
hash_request_tokens, hash_request_tokens,
unify_kv_cache_configs) unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor) KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16, ...@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2, num_kv_heads=2,
head_size=64, head_size=64,
dtype=torch.float32, dtype=torch.float32,
use_mla=False): use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size, return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
use_mla=use_mla) use_mla=use_mla,
sliding_window=sliding_window)
def test_none_hash(): def test_none_hash():
...@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs(): ...@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
unify_kv_cache_configs(diff_kv_cache_config) unify_kv_cache_configs(diff_kv_cache_config)
def test_merge_kv_cache_spec():
    """Check ``merge`` on lists of per-layer KV cache specs.

    Scenarios, in order: identical specs, mismatched ``num_kv_heads``,
    mixed spec types, conflicting sliding windows, equal sliding windows,
    and one sliding window combined with ``None``.
    """
    # Identical specs merge into a spec carrying the shared field values.
    identical_specs = [new_kv_cache_spec(num_kv_heads=32) for _ in range(2)]
    merged = identical_specs[0].merge(identical_specs)
    assert merged.block_size == 16
    assert merged.num_kv_heads == 32
    assert merged.head_size == 64
    assert merged.dtype == torch.float32
    assert merged.sliding_window is None

    # A different number of KV heads is rejected.
    mismatched_head_specs = [
        new_kv_cache_spec(num_kv_heads=heads) for heads in (32, 16)
    ]
    with pytest.raises(AssertionError):
        mismatched_head_specs[0].merge(mismatched_head_specs)

    # Mixing spec types is rejected whichever element drives the merge.
    full_spec = new_kv_cache_spec(num_kv_heads=32)
    sliding_spec = SlidingWindowSpec(
        block_size=full_spec.block_size,
        num_kv_heads=full_spec.num_kv_heads,
        head_size=full_spec.head_size,
        dtype=full_spec.dtype,
        use_mla=full_spec.use_mla,
        sliding_window=1,
    )
    mixed_type_specs = [full_spec, sliding_spec]
    for driver in mixed_type_specs:
        with pytest.raises(AssertionError):
            driver.merge(mixed_type_specs)

    # Two distinct non-None sliding windows cannot be reconciled.
    conflicting_window_specs = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
    ]
    with pytest.raises(ValueError):
        conflicting_window_specs[0].merge(conflicting_window_specs)

    # Equal sliding windows are kept as-is.
    equal_window_specs = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1)
        for _ in range(2)
    ]
    merged = equal_window_specs[0].merge(equal_window_specs)
    assert merged.sliding_window == 1

    # A None sliding window defers to the one that is set.
    window_and_none_specs = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
    ]
    merged = window_and_none_specs[0].merge(window_and_none_specs)
    assert merged.sliding_window == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [ ("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384), ("Qwen/Qwen1.5-7B", 16385, 16384),
......
...@@ -84,7 +84,7 @@ def test_prefill(hash_algo): ...@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55, blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Check full block metadata # Check full block metadata
parent_block_hash = None parent_block_hash = None
...@@ -107,13 +107,13 @@ def test_prefill(hash_algo): ...@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
...@@ -141,13 +141,13 @@ def test_prefill(hash_algo): ...@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids) req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [6] assert blocks.get_block_ids() == [[6]]
# Although we only have 6 free blocks, we have 8 blocks in # Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal. # the free block queue due to lazy removal.
...@@ -171,7 +171,7 @@ def test_prefill(hash_algo): ...@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None assert manager.block_pool.free_block_queue.free_list_tail is None
...@@ -208,7 +208,7 @@ def test_prefill_plp(): ...@@ -208,7 +208,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55, blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks] req0_block_hashes = [b.block_hash for b in blocks.blocks]
# Check full block metadata # Check full block metadata
...@@ -233,13 +233,13 @@ def test_prefill_plp(): ...@@ -233,13 +233,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
...@@ -277,11 +277,11 @@ def test_prefill_plp(): ...@@ -277,11 +277,11 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids() block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0 # Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4] assert block_ids != [[1, 2, 3, 4]]
# Request #2 block hashes are valid since request #0 hashes are. # Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts. # Check block reference counts.
for block_id in block_ids: for block_id in block_ids[0]:
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
manager.free(req2) manager.free(req2)
...@@ -307,7 +307,7 @@ def test_decode(): ...@@ -307,7 +307,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55, blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
# Append slots without allocating a new block. # Append slots without allocating a new block.
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
...@@ -379,12 +379,12 @@ def test_evict(): ...@@ -379,12 +379,12 @@ def test_evict():
# Touch the first 2 blocks. # Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3))) req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2] assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16 assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [10] assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7 assert manager.block_pool.free_block_queue.num_free_blocks == 7
...@@ -625,7 +625,7 @@ def test_mm_prefix_caching(): ...@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59, blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
...@@ -686,7 +686,7 @@ def test_cache_key_salting(): ...@@ -686,7 +686,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59, blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
...@@ -797,7 +797,7 @@ def test_reset_prefix_cache(): ...@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55) blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [[1, 2, 3, 4]]
unique_token_ids = [4] * 7 unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
...@@ -808,7 +808,7 @@ def test_reset_prefix_cache(): ...@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7, blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16, len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [[5]]
# Failed to reset prefix cache because some blocks are not freed yet. # Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache() assert not manager.reset_prefix_cache()
......
...@@ -9,9 +9,11 @@ import torch ...@@ -9,9 +9,11 @@ import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
InputBatch) from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024 VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20 NUM_OUTPUT_TOKENS = 20
...@@ -22,6 +24,27 @@ CUDA_DEVICES = [ ...@@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64 MAX_NUM_PROMPT_TOKENS = 64
def get_kv_cache_config() -> KVCacheConfig:
    """Build the minimal single-group KV cache config used by these tests.

    Returns:
        A ``KVCacheConfig`` with ``num_blocks=10``, one ``KVCacheTensor`` of
        ``size=1024`` for ``"layer.0"``, and one KV cache group that holds
        that layer with a float16 ``FullAttentionSpec``.
    """
    layer_name = "layer.0"
    attention_spec = FullAttentionSpec(
        block_size=1,
        num_kv_heads=1,
        head_size=16,
        dtype=torch.float16,
        use_mla=False,
    )
    single_group = KVCacheGroupSpec(
        layer_names=[layer_name],
        kv_cache_spec=attention_spec,
    )
    return KVCacheConfig(
        num_blocks=10,
        tensors={layer_name: KVCacheTensor(size=1024)},
        kv_cache_groups=[single_group],
    )
def _compare_objs(obj1, obj2): def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([ attr_names = set([
...@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2): ...@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
elif isinstance(a, np.ndarray): elif isinstance(a, np.ndarray):
if np.allclose(a, b): if np.allclose(a, b):
is_same = True is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)): elif isinstance(a, (BlockTable, SamplingMetadata)):
_compare_objs(a, b) _compare_objs(a, b)
is_same = True # if we make it here must be same is_same = True # if we make it here must be same
...@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
mm_inputs=[], mm_inputs=[],
mm_positions=[], mm_positions=[],
block_ids=[], block_ids=[[]],
generator=None, generator=None,
num_computed_tokens=len(output_token_ids), num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
...@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch: InputBatch = InputBatch( input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size, max_num_reqs=batch_size,
max_model_len=1024, max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(),
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []
req_id_reqs = {} req_id_reqs = {}
...@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, ...@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
input_batch: InputBatch = InputBatch( input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size, max_num_reqs=batch_size,
max_model_len=1024, max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(),
) )
ref_input_batch: InputBatch = InputBatch( ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size, max_num_reqs=batch_size,
max_model_len=1024, max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024, vocab_size=1024,
kv_cache_config=get_kv_cache_config(),
) )
reqs: list[CachedRequestState] = [] reqs: list[CachedRequestState] = []
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import weakref
import pytest import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
...@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner): ...@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
""" """
Only perform necessary steps in GPUModelRunner.initialize_kv_cache() Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
""" """
kv_cache_spec = FullAttentionSpec(block_size=16, kv_cache_config = KVCacheConfig(
num_kv_heads=1, num_blocks=10,
head_size=64, tensors={
dtype=torch.float16, "layer.0": KVCacheTensor(size=1024),
use_mla=False) },
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()( kv_cache_groups=[
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table) KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config,
)
runner.initialize_attn_backend(kv_cache_config)
@pytest.fixture @pytest.fixture
...@@ -48,10 +70,12 @@ def model_runner(): ...@@ -48,10 +70,12 @@ def model_runner():
swap_space=0, swap_space=0,
cache_dtype="auto", cache_dtype="auto",
) )
parallel_config = ParallelConfig()
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
parallel_config=parallel_config,
) )
device = "cuda" device = "cuda"
...@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
block_ids=[0], block_ids=[[0]],
num_computed_tokens=0, num_computed_tokens=0,
lora_request=None, lora_request=None,
)) ))
...@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner, ...@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_index = model_runner.input_batch.req_id_to_index[req_id] req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table block_table = model_runner.input_batch.block_table[0]
req_state = model_runner.requests[req_id] req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): if block_table.num_blocks_per_row[req_index] != len(
req_state.block_ids[0]):
return False return False
num_blocks = block_table.num_blocks_per_row[req_index] num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] == return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all() req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner): def test_update_states_new_request(model_runner):
...@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id, req_id=req_id,
resumed_from_preemption=False, resumed_from_preemption=False,
new_token_ids=[], new_token_ids=[],
new_block_ids=[], new_block_ids=[[]],
num_computed_tokens=0, num_computed_tokens=0,
) )
......
...@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main ...@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True #gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq, TheBloke/Llama-2-7B-GPTQ, main gptq, TheBloke/Llama-2-7B-GPTQ, main
......
...@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
for new_req in scheduler_output.scheduled_new_reqs: for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load: if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids, meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids, block_ids=new_req.block_ids[0],
block_size=self._block_size, block_size=self._block_size,
is_store=False) is_store=False)
total_need_load += 1 total_need_load += 1
...@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# the original prompt tokens. # the original prompt tokens.
if not self._found_match_for_request(new_req): if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids, meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids, block_ids=new_req.block_ids[0],
block_size=self._block_size, block_size=self._block_size,
is_store=True) is_store=True)
...@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): For resumed req, new_block_ids is all # NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request. # of the block_ids for the request.
block_ids = cached_req.new_block_ids block_ids = cached_req.new_block_ids[0]
meta.add_request(token_ids=token_ids, meta.add_request(token_ids=token_ids,
block_ids=block_ids, block_ids=block_ids,
......
...@@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -67,13 +67,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_model_len = self.runner.model_config.max_model_len max_model_len = self.runner.model_config.max_model_len
assert max_model_len == 32768,\ assert max_model_len == 32768,\
"AITER MLA requires max_model_len=32768" "AITER MLA requires max_model_len=32768"
assert self.runner.block_size == 1, "AITER MLA" \ assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1." "only supports block size 1."
def _get_paged_kv_tensors( def _get_paged_kv_tensors(
self, block_table: torch.Tensor, self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.runner.block_size page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size block_table_bounds = (seq_lens + page_size - 1) // page_size
mask = (torch.arange(block_table.size(1), mask = (torch.arange(block_table.size(1),
......
...@@ -32,9 +32,16 @@ class KVCacheBlocks: ...@@ -32,9 +32,16 @@ class KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks.""" """Creates a new KVCacheBlocks instance with no blocks."""
return cls([]) return cls([])
def get_block_ids(self) -> list[int]: def get_block_ids(self) -> list[list[int]]:
"""Converts the KVCacheBlocks instance to a list of block IDs.""" """
return [block.block_id for block in self.blocks] Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* each inner list contains the block_ids of the blocks in that group
"""
return [[block.block_id for block in self.blocks]]
def get_unhashed_block_ids(self) -> list[int]: def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance.""" """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...@@ -300,9 +307,9 @@ class KVCacheManager: ...@@ -300,9 +307,9 @@ class KVCacheManager:
self, self,
request: Request, request: Request,
num_running_requests: int, num_running_requests: int,
) -> int: ) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests """Calculate the number of common prefix blocks shared by all requests
in the RUNNING state. in the RUNNING state for each kv cache group.
The function determines this by selecting any request and iterating The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its through its blocks. A block is considered a common prefix block if its
...@@ -332,11 +339,14 @@ class KVCacheManager: ...@@ -332,11 +339,14 @@ class KVCacheManager:
requests in the current step. requests in the current step.
Returns: Returns:
int: The number of common prefix blocks. list[int]: The number of common prefix blocks for each kv cache
group.
""" """
assert request.status == RequestStatus.RUNNING assert request.status == RequestStatus.RUNNING
return self.single_type_manager.get_num_common_prefix_blocks( return [
request.request_id, num_running_requests) self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
]
def free_block_hashes(self, request: Request) -> None: def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request. """Discard the block hashes for the request.
...@@ -354,10 +364,8 @@ class KVCacheManager: ...@@ -354,10 +364,8 @@ class KVCacheManager:
""" """
return self.block_pool.take_events() return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[int]: def get_block_ids(self, request_id: str) -> list[list[int]]:
"""Get the block ids of a request.""" """Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks assert request_id in self.single_type_manager.req_to_blocks
return [ return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
block.block_id ).get_block_ids()
for block in self.single_type_manager.req_to_blocks[request_id]
]
...@@ -577,14 +577,12 @@ def create_kv_cache_group_specs( ...@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
""" """
kv_cache_groups = [] kv_cache_groups = []
for layer_names_one_group in grouped_layer_names: for layer_names_one_group in grouped_layer_names:
layer_spec = kv_cache_spec[layer_names_one_group[0]] layer_specs = [
assert all( kv_cache_spec[layer_name] for layer_name in layer_names_one_group
kv_cache_spec[layer_name] == layer_spec ]
for layer_name in layer_names_one_group[1:]), ( merged_layer_spec = layer_specs[0].merge(layer_specs)
"All layers in the same KV cache group must share the same "
"KVCacheSpec.")
kv_cache_groups.append( kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, layer_spec)) KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
return kv_cache_groups return kv_cache_groups
...@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ...@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
head_size=spec.head_size, head_size=spec.head_size,
dtype=spec.dtype, dtype=spec.dtype,
use_mla=spec.use_mla, use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
) )
......
...@@ -26,7 +26,7 @@ class NewRequestData: ...@@ -26,7 +26,7 @@ class NewRequestData:
mm_hashes: list[str] mm_hashes: list[str]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams sampling_params: SamplingParams
block_ids: list[int] block_ids: list[list[int]]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
...@@ -34,7 +34,7 @@ class NewRequestData: ...@@ -34,7 +34,7 @@ class NewRequestData:
def from_request( def from_request(
cls, cls,
request: Request, request: Request,
block_ids: list[int], block_ids: list[list[int]],
) -> NewRequestData: ) -> NewRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
...@@ -85,7 +85,7 @@ class CachedRequestData: ...@@ -85,7 +85,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs. # request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool resumed_from_preemption: bool
new_token_ids: list[int] new_token_ids: list[int]
new_block_ids: list[int] new_block_ids: list[list[int]]
num_computed_tokens: int num_computed_tokens: int
@classmethod @classmethod
...@@ -94,7 +94,7 @@ class CachedRequestData: ...@@ -94,7 +94,7 @@ class CachedRequestData:
request: Request, request: Request,
resumed_from_preemption: bool, resumed_from_preemption: bool,
new_token_ids: list[int], new_token_ids: list[int],
new_block_ids: list[int], new_block_ids: list[list[int]],
) -> CachedRequestData: ) -> CachedRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
...@@ -131,9 +131,9 @@ class SchedulerOutput: ...@@ -131,9 +131,9 @@ class SchedulerOutput:
# E.g., if a request has [0, 1], it could mean the vision encoder needs # E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step. # to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: dict[str, list[int]] scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests. # Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention. # This can be used for cascade attention.
num_common_prefix_blocks: int num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current # Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests # steps. This is used to notify the workers about the finished requests
......
...@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface): ...@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding. # uses structured decoding.
structured_output_request_ids: dict[str, int] = {} structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {} req_to_new_block_ids: dict[str, list[list[int]]] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
# Encoder-related. # Encoder-related.
...@@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface): ...@@ -477,7 +477,8 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue. # Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention. # This can be potentially used for cascade attention.
num_common_prefix_blocks = 0 num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running: if self.running:
any_request = self.running[0] any_request = self.running[0]
num_common_prefix_blocks = ( num_common_prefix_blocks = (
...@@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface): ...@@ -564,7 +565,7 @@ class Scheduler(SchedulerInterface):
request: Request, request: Request,
num_scheduled_tokens: int, num_scheduled_tokens: int,
num_scheduled_spec_tokens: int, num_scheduled_spec_tokens: int,
new_block_ids: list[int], new_block_ids: list[list[int]],
resumed_from_preemption: bool, resumed_from_preemption: bool,
) -> CachedRequestData: ) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...@@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface): ...@@ -939,7 +940,9 @@ class Scheduler(SchedulerInterface):
""" """
if self.connector is None: if self.connector is None:
return False, None return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
return self.connector.request_finished(request, block_ids) return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool: def _update_waiting_for_remote_kv(self, request: Request) -> bool:
...@@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface): ...@@ -956,9 +959,10 @@ class Scheduler(SchedulerInterface):
""" """
if request.request_id not in self.finished_recving_kv_req_ids: if request.request_id not in self.finished_recving_kv_req_ids:
return False return False
assert len(self.kv_cache_config.kv_cache_groups
) == 1, "KV connector only supports one KV cache group now"
# Now that the blocks are ready, actually cache them. # Now that the blocks are ready, actually cache them.
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
num_computed_tokens = len(block_ids) * self.block_size num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens: if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1 num_computed_tokens -= 1
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import torch import torch
from typing_extensions import Self
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -53,6 +56,16 @@ class KVCacheSpec: ...@@ -53,6 +56,16 @@ class KVCacheSpec:
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
"All layers in the same KV cache group must share the same "
"type_id.")
return copy.deepcopy(specs[0])
@dataclass @dataclass
class AttentionSpec(KVCacheSpec): class AttentionSpec(KVCacheSpec):
...@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec): ...@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):
@dataclass @dataclass
class FullAttentionSpec(AttentionSpec): class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
@property @property
def type_id(self) -> str: def type_id(self) -> str:
...@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec): ...@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of FullAttentionSpec objects into a single
    FullAttentionSpec object.

    The parent ``merge`` produces the base merged spec; this override then
    sets its ``sliding_window`` from the inputs: ``None`` when no spec
    defines a window, otherwise the single window size shared by the specs
    that define one.

    Raises:
        ValueError: If the specs carry more than one distinct non-None
            sliding window size.
    """
    merged_spec = super().merge(specs)
    # Distinct window sizes among the specs that actually set one.
    window_sizes = {
        spec.sliding_window
        for spec in specs if spec.sliding_window is not None
    }
    if len(window_sizes) > 1:
        raise ValueError(
            "All sliding window layers in the same KV cache group "
            "must have the same window size.")
    # At most one size is left here; an empty set means "no sliding window".
    merged_spec.sliding_window = next(iter(window_sizes), None)
    return merged_spec
@dataclass @dataclass
class SlidingWindowSpec(AttentionSpec): class SlidingWindowSpec(AttentionSpec):
......
...@@ -4,6 +4,8 @@ import numpy as np ...@@ -4,6 +4,8 @@ import numpy as np
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -96,3 +98,48 @@ class BlockTable: ...@@ -96,3 +98,48 @@ class BlockTable:
def get_numpy_array(self) -> np.ndarray: def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table.""" """Returns the numpy array of the block table."""
return self.block_table_np return self.block_table_np
class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Holds one ``BlockTable`` per entry of ``kv_cache_config.kv_cache_groups``
    (in the same order) and forwards every row operation to each of them.
    """

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, kv_cache_config: KVCacheConfig) -> None:
        """Create one BlockTable per KV cache group.

        Each group's per-request block budget is ``max_model_len`` divided
        by that group's block size via ``cdiv`` (presumably rounding up --
        confirm in vllm.utils).
        """
        self.block_tables = [
            BlockTable(max_num_reqs,
                       cdiv(max_model_len, group.kv_cache_spec.block_size),
                       max_num_batched_tokens, pin_memory, device)
            for group in kv_cache_config.kv_cache_groups
        ]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``append_row`` on each table with ``block_ids[group]``."""
        for group_idx, table in enumerate(self.block_tables):
            table.append_row(block_ids[group_idx], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``add_row`` on each table with ``block_ids[group]``."""
        for group_idx, table in enumerate(self.block_tables):
            table.add_row(block_ids[group_idx], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Forward ``move_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Forward ``swap_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Forward ``commit(num_reqs)`` to every group's table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Forward ``clear()`` to every group's table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
...@@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest ...@@ -11,10 +11,11 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -29,7 +30,7 @@ class CachedRequestState: ...@@ -29,7 +30,7 @@ class CachedRequestState:
sampling_params: SamplingParams sampling_params: SamplingParams
generator: Optional[torch.Generator] generator: Optional[torch.Generator]
block_ids: list[int] block_ids: list[list[int]]
num_computed_tokens: int num_computed_tokens: int
output_token_ids: list[int] output_token_ids: list[int]
...@@ -58,15 +59,14 @@ class InputBatch: ...@@ -58,15 +59,14 @@ class InputBatch:
self, self,
max_num_reqs: int, max_num_reqs: int,
max_model_len: int, max_model_len: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
device: torch.device, device: torch.device,
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
kv_cache_config: KVCacheConfig,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.device = device self.device = device
self.pin_memory = pin_memory self.pin_memory = pin_memory
...@@ -99,12 +99,13 @@ class InputBatch: ...@@ -99,12 +99,13 @@ class InputBatch:
self.num_computed_tokens_cpu_tensor.numpy() self.num_computed_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = BlockTable( self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs, max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req, max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory, pin_memory=pin_memory,
device=device, device=device,
kv_cache_config=kv_cache_config,
) )
# Sampling-related. # Sampling-related.
......
This diff is collapsed.
...@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output) # req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# self.input_batch: InputBatch # Persistent batch.
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.vocab_size,
)
# Cached torch/numpy tensor # Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer. # The pytorch tensor and numpy array share the same buffer.
...@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.block_table_cpu = torch.zeros( self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req), (self.max_num_reqs, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype, dtype=torch.int32,
device="cpu") device="cpu")
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
...@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here # NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large # because torch.index_select is much faster than np.take for large
# tensors. # tensors.
block_table_cpu = self.input_batch.block_table.get_cpu_tensor() block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size, np.add(block_numbers * self.block_size,
block_offsets, block_offsets,
out=self.input_batch.block_table. out=self.input_batch.block_table[0].
slot_mapping_np[:total_num_scheduled_tokens]) slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata. # Prepare the attention metadata.
...@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.position_ids = self.positions_cpu[: self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to( padded_total_num_scheduled_tokens].to(
self.device) self.device)
self.input_batch.block_table.slot_mapping_cpu[ self.input_batch.block_table[0].slot_mapping_cpu[
total_num_scheduled_tokens:] = _PAD_SLOT_ID total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = ( slot_mapping = (
self.input_batch.block_table. self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device)) self.device))
block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables = self.block_table_cpu[:self.max_num_reqs]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
block_tables = block_tables.to(self.device) block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device) self.device)
...@@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1263,6 +1254,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not " "Hybrid models with more than one KV cache type are not "
"supported yet.") "supported yet.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config,
)
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups: for kv_cache_group in kv_cache_config.kv_cache_groups:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment