Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
646d62f6
Unverified
Commit
646d62f6
authored
Jun 09, 2025
by
Nick Hill
Committed by
GitHub
Jun 10, 2025
Browse files
[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
6cd4ae8a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
140 additions
and
142 deletions
+140
-142
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+22
-22
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+4
-4
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+1
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+2
-2
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+4
-4
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+53
-48
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+18
-20
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+4
-4
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-8
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+24
-26
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+3
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
646d62f6
...
...
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Check full block metadata
parent_block_hash
=
None
...
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
...
...
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
6
]
]
assert
blocks
.
get_block_ids
()
==
(
[
6
]
,
)
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
...
...
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
]
assert
blocks
.
get_block_ids
()
==
(
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
...
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]
)
# Check full block metadata
parent_block_hash
=
None
...
...
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
]
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
13
],
[
14
],
[
15
]
]
assert
blocks
.
get_block_ids
()
==
(
[
13
],
[
14
],
[
15
]
)
for
block_per_group
in
computed_blocks
.
blocks
:
for
block
in
block_per_group
:
if
block
!=
manager
.
block_pool
.
null_block
:
...
...
@@ -374,7 +374,7 @@ def test_prefill_plp():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
# Check full block metadata
...
...
@@ -400,13 +400,13 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
...
...
@@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
==
req0_block_hashes
assert
block_ids
!=
[
[
1
,
2
,
3
,
4
]
]
assert
block_ids
!=
(
[
1
,
2
,
3
,
4
]
,
)
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
...
...
@@ -474,7 +474,7 @@ def test_decode():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
...
...
@@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
]
,
)
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
10
]
]
assert
blocks
.
get_block_ids
()
==
(
[
10
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
...
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
...
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
646d62f6
...
...
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
# block_ids should be
list
[list[int]]
block_ids
=
(
[
0
]
,
)
,
# block_ids should be
tuple
[list[int]]
num_computed_tokens
=
0
,
lora_request
=
None
,
))
...
...
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
block_table
=
multi_group_block_table
[
0
]
# req_state.block_ids is now
list
[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now
tuple
[list[int]
, ...
] for MultiGroupBlockTable
# Extract the first group's block IDs
if
isinstance
(
req_state
.
block_ids
[
0
],
list
):
# New format:
list
[list[int]] - extract first group
# New format:
tuple
[list[int]
, ...
] - extract first group
req_block_ids
=
req_state
.
block_ids
[
0
]
else
:
# Legacy format: list[int] - use directly
...
...
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
)
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
646d62f6
...
...
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[
[]
]
,
block_ids
=
(
[]
,
)
,
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
646d62f6
...
...
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
block_ids
=
(
[
0
]
,
)
,
num_computed_tokens
=
0
,
lora_request
=
None
,
))
...
...
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
)
...
...
vllm/v1/core/block_pool.py
View file @
646d62f6
...
...
@@ -89,8 +89,8 @@ class BlockPool:
BlockHashWithGroupId
(
block_hash
,
group_id
))
if
not
cached_blocks_one_group
:
return
None
first_block
_id
=
next
(
iter
(
cached_blocks_one_group
))
cached_blocks
.
append
(
cached_blocks_one_group
[
first_block
_id
]
)
first_block
=
next
(
iter
(
cached_blocks_one_group
.
values
()
))
cached_blocks
.
append
(
first_block
)
return
cached_blocks
def
cache_full_blocks
(
...
...
@@ -260,7 +260,7 @@ class BlockPool:
return
True
return
False
def
touch
(
self
,
blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
def
touch
(
self
,
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
...
...
@@ -299,7 +299,7 @@ class BlockPool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
)
num_used_blocks
=
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
"Failed to reset prefix cache because some "
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
646d62f6
...
...
@@ -5,8 +5,7 @@ from typing import Callable, Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SingleTypeKVCacheManager
,
get_manager_for_kv_cache_spec
)
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.request
import
Request
...
...
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
self
.
block_pool
=
BlockPool
(
kv_cache_config
.
num_blocks
,
enable_caching
,
enable_kv_cache_events
)
self
.
single_type_managers
:
list
[
SingleTypeKVCacheManager
]
=
[]
# Needs special handling for find_longest_cache_hit if eagle is enabled
self
.
use_eagle
=
use_eagle
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
i
].
kv_cache_spec
self
.
single_type_managers
.
append
(
get_manager_for_kv_cache_spec
(
kv_cache_spec
=
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
))
self
.
single_type_managers
=
tuple
(
get_manager_for_kv_cache_spec
(
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
)
for
i
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
))
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
int
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
int
:
"""
Get the number of blocks needed to be allocated for the request.
...
...
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""
Add the new computed blocks to the request.
...
...
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
)
->
list
[
list
[
KVCacheBlock
]]:
num_tokens
:
int
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
...
...
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
Returns:
The new allocated blocks.
"""
new_blocks
=
[]
for
manager
in
self
.
single_type_managers
:
new_blocks
.
append
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
))
return
new_blocks
return
tuple
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
for
manager
in
self
.
single_type_managers
)
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
...
...
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
for
manager
in
self
.
single_type_managers
:
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
def
get_blocks
(
self
,
request_id
:
str
)
->
list
[
list
[
KVCacheBlock
]]:
def
get_blocks
(
self
,
request_id
:
str
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Get the blocks for the request.
"""
return
[
return
tuple
(
manager
.
req_to_blocks
.
get
(
request_id
)
or
[]
for
manager
in
self
.
single_type_managers
]
for
manager
in
self
.
single_type_managers
)
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
pass
...
...
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
hit_blocks
=
self
.
single_type_managers
[
0
].
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
max_length
=
max_cache_hit_length
,
...
...
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now."
)
if
max
(
self
.
full_attention_group_ids
)
<
min
(
self
.
other_group_ids
):
self
.
full_attn_first
=
True
elif
max
(
self
.
other_group_ids
)
<
min
(
self
.
full_attention_group_ids
):
self
.
full_attn_first
=
False
else
:
raise
ValueError
(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks."
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
]
,
...
],
int
]:
"""
Find the longest cache hit for the request.
...
...
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
))
hit_length
=
len
(
hit_blocks_other_attn
[
0
])
*
self
.
other_block_size
# NOTE: the prefix cache hit length must be a multipl
y
of block_size as
# NOTE: the prefix cache hit length must be a multipl
e
of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multipl
y
of the block size of
# of other attention is ensured to be a multipl
e
of the block size of
# full attention layers in current implementation, because hit_length is
# a multipl
y
of other attention's block size, and other attention's
# block size is a multipl
y
of full attention's block size (verified in
# a multipl
e
of other attention's block size, and other attention's
# block size is a multipl
e
of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert
hit_length
%
self
.
full_attention_block_size
==
0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for
i
in
range
(
len
(
hit_blocks_full_attn
)):
del
hit_blocks_full_attn
[
i
][
hit_length
//
self
.
full_attention_block_size
:]
for
group_hit_blocks
in
hit_blocks_full_attn
:
del
group_hit_blocks
[
hit_length
//
self
.
full_attention_block_size
:]
# Merge the hit blocks of full attention and other attention.
hit_blocks
=
hit_blocks_other_attn
for
group_id
,
blocks
in
enumerate
(
hit_blocks_full_attn
):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks
.
insert
(
group_id
,
blocks
)
if
self
.
full_attn_first
:
hit_blocks
=
hit_blocks_full_attn
+
hit_blocks_other_attn
else
:
hit_blocks
=
hit_blocks_other_attn
+
hit_blocks_full_attn
return
hit_blocks
,
hit_length
...
...
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
else
:
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
vllm/v1/core/kv_cache_manager.py
View file @
646d62f6
...
...
@@ -21,11 +21,11 @@ logger = init_logger(__name__)
@
dataclass
class
KVCacheBlocks
:
"""
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks
:
list
[
list
[
KVCacheBlock
]]
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
...
...
@@ -37,21 +37,19 @@ class KVCacheBlocks:
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
"""Adds two KVCacheBlocks instances."""
return
KVCacheBlocks
(
[
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)])
tuple
(
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)))
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
)
->
tuple
[
list
[
int
]
,
...
]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list
[list[int]]: A t
wo-level
list where
* the outer
list
corresponds to KV cache groups
tuple
[list[int]
, ...
]: A t
uple of
list
s
where
* the outer
tuple
corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
block_ids
=
[]
for
group
in
self
.
blocks
:
block_ids
.
append
([
blk
.
block_id
for
blk
in
group
])
return
block_ids
return
tuple
([
blk
.
block_id
for
blk
in
group
]
for
group
in
self
.
blocks
)
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...
...
@@ -63,7 +61,7 @@ class KVCacheBlocks:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
(
[
[]
for
_
in
range
(
len
(
self
.
blocks
))
]
)
return
KVCacheBlocks
(
tuple
(
[]
for
_
in
range
(
len
(
self
.
blocks
))
)
)
class
KVCacheManager
:
...
...
@@ -232,9 +230,8 @@ class KVCacheManager:
if
new_computed_blocks
is
not
None
:
new_computed_block_list
=
new_computed_blocks
.
blocks
else
:
new_computed_block_list
=
[
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
))
]
new_computed_block_list
=
tuple
(
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)))
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
...
...
@@ -267,7 +264,7 @@ class KVCacheManager:
if
self
.
enable_caching
:
self
.
block_pool
.
touch
(
new_computed_block_list
)
else
:
assert
all
(
not
blocks
for
blocks
in
new_computed_block_list
),
(
assert
not
any
(
new_computed_block_list
),
(
"Computed blocks should be empty when "
"prefix caching is disabled"
)
...
...
@@ -378,17 +375,18 @@ class KVCacheManager:
"""
return
self
.
block_pool
.
take_events
()
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
tuple
[
list
[
int
]
,
...
]:
"""Get the block ids of a request."""
return
KVCacheBlocks
(
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request."""
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
([[]
for
_
in
range
(
self
.
num_kv_cache_groups
)])
return
KVCacheBlocks
(
tuple
([]
for
_
in
range
(
self
.
num_kv_cache_groups
)))
vllm/v1/core/sched/output.py
View file @
646d62f6
...
...
@@ -27,7 +27,7 @@ class NewRequestData:
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
...
...
@@ -35,7 +35,7 @@ class NewRequestData:
def
from_request
(
cls
,
request
:
Request
,
block_ids
:
list
[
list
[
int
]],
block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
NewRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
@@ -86,7 +86,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_token_ids
:
list
[
int
]
new_block_ids
:
list
[
list
[
int
]]
new_block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
@
classmethod
...
...
@@ -95,7 +95,7 @@ class CachedRequestData:
request
:
Request
,
resumed_from_preemption
:
bool
,
new_token_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
CachedRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
646d62f6
...
...
@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
list
[
int
]]]
=
{}
req_to_new_block_ids
:
dict
[
str
,
tuple
[
list
[
int
]
,
...
]]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
token_budget
=
self
.
max_num_scheduled_tokens
# Encoder-related.
...
...
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
# Count the number of pr
i
fix cached tokens.
# Count the number of pr
e
fix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
# Encoder-related.
...
...
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
resumed_from_preemption
:
bool
,
)
->
CachedRequestData
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...
...
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens
=
min
(
num_computed_tokens
,
request
.
num_tokens
)
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
self
.
kv_cache_manager
.
cache_blocks
(
request
,
self
.
kv_cache_manager
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
,
)
self
.
kv_cache_manager
.
cache_blocks
(
request
,
num_computed_tokens
)
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
646d62f6
...
...
@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
...
...
@@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
]
for block size 4
(
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
)
for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
...
...
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
"FullAttentionManager can only be used for full attention groups"
)
computed_blocks
:
list
[
list
[
KVCacheBlock
]]
=
[
[]
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...]
=
tuple
(
[]
for
_
in
range
(
len
(
kv_cache_group_ids
)))
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
for
i
in
range
(
max_num_blocks
):
block_hash
=
block_hashes
[
i
]
for
i
,
block_hash
in
zip
(
range
(
max_num_blocks
),
block_hashes
):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hash
,
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
computed
_blocks
[
j
]
.
append
(
cached
_block
[
j
]
)
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
.
append
(
cached
)
else
:
break
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
computed
_blocks
[
j
]
.
pop
()
if
use_eagle
and
computed_blocks
[
0
]:
for
computed
in
computed_blocks
:
computed
.
pop
()
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
),
(
"SlidingWindowManager can only be used for sliding window groups"
)
...
...
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
computed_blocks
=
[
[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
computed_blocks
=
tuple
(
[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))
)
num_contiguous_blocks
=
0
match_found
=
False
# Search from right to left and early stop when a match is found.
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hashes
[
i
],
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
computed
_blocks
[
j
]
[
i
]
=
cached
_block
[
j
]
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
[
i
]
=
cached
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
)
:
if
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
del
computed
_blocks
[
j
]
[
i
+
num_contiguous_blocks
:]
for
computed
in
computed_blocks
:
del
computed
[
i
+
num_contiguous_blocks
:]
match_found
=
True
break
else
:
...
...
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if
not
match_found
:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
del
computed
_blocks
[
j
]
[
num_contiguous_blocks
:]
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
computed
_blocks
[
j
]
.
pop
()
for
computed
in
computed_blocks
:
del
computed
[
num_contiguous_blocks
:]
if
use_eagle
and
computed_blocks
[
0
]:
for
computed
in
computed_blocks
:
computed
.
pop
()
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
vllm/v1/worker/block_table.py
View file @
646d62f6
...
...
@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
for
block_size
in
block_sizes
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
append_row
(
self
,
block_ids
:
tuple
[
list
[
int
],
...],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
def
add_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
tuple
[
list
[
int
]
,
...
],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
646d62f6
...
...
@@ -30,7 +30,7 @@ class CachedRequestState:
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
output_token_ids
:
list
[
int
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment