Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
646d62f6
Unverified
Commit
646d62f6
authored
Jun 09, 2025
by
Nick Hill
Committed by
GitHub
Jun 10, 2025
Browse files
[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
6cd4ae8a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
140 additions
and
142 deletions
+140
-142
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+22
-22
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+4
-4
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+1
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+2
-2
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+4
-4
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+53
-48
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+18
-20
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+4
-4
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-8
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+24
-26
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+3
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
646d62f6
...
...
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Check full block metadata
parent_block_hash
=
None
...
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
...
...
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
6
]
]
assert
blocks
.
get_block_ids
()
==
(
[
6
]
,
)
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
...
...
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
]
assert
blocks
.
get_block_ids
()
==
(
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
...
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]
)
# Check full block metadata
parent_block_hash
=
None
...
...
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
]
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
13
],
[
14
],
[
15
]
]
assert
blocks
.
get_block_ids
()
==
(
[
13
],
[
14
],
[
15
]
)
for
block_per_group
in
computed_blocks
.
blocks
:
for
block
in
block_per_group
:
if
block
!=
manager
.
block_pool
.
null_block
:
...
...
@@ -374,7 +374,7 @@ def test_prefill_plp():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
# Check full block metadata
...
...
@@ -400,13 +400,13 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
...
...
@@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
==
req0_block_hashes
assert
block_ids
!=
[
[
1
,
2
,
3
,
4
]
]
assert
block_ids
!=
(
[
1
,
2
,
3
,
4
]
,
)
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
...
...
@@ -474,7 +474,7 @@ def test_decode():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
...
...
@@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
]
,
)
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
10
]
]
assert
blocks
.
get_block_ids
()
==
(
[
10
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
...
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
...
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
646d62f6
...
...
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
# block_ids should be
list
[list[int]]
block_ids
=
(
[
0
]
,
)
,
# block_ids should be
tuple
[list[int]]
num_computed_tokens
=
0
,
lora_request
=
None
,
))
...
...
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
block_table
=
multi_group_block_table
[
0
]
# req_state.block_ids is now
list
[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now
tuple
[list[int]
, ...
] for MultiGroupBlockTable
# Extract the first group's block IDs
if
isinstance
(
req_state
.
block_ids
[
0
],
list
):
# New format:
list
[list[int]] - extract first group
# New format:
tuple
[list[int]
, ...
] - extract first group
req_block_ids
=
req_state
.
block_ids
[
0
]
else
:
# Legacy format: list[int] - use directly
...
...
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
)
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
646d62f6
...
...
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[
[]
]
,
block_ids
=
(
[]
,
)
,
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
646d62f6
...
...
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
block_ids
=
(
[
0
]
,
)
,
num_computed_tokens
=
0
,
lora_request
=
None
,
))
...
...
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
)
...
...
vllm/v1/core/block_pool.py
View file @
646d62f6
...
...
@@ -89,8 +89,8 @@ class BlockPool:
BlockHashWithGroupId
(
block_hash
,
group_id
))
if
not
cached_blocks_one_group
:
return
None
first_block
_id
=
next
(
iter
(
cached_blocks_one_group
))
cached_blocks
.
append
(
cached_blocks_one_group
[
first_block
_id
]
)
first_block
=
next
(
iter
(
cached_blocks_one_group
.
values
()
))
cached_blocks
.
append
(
first_block
)
return
cached_blocks
def
cache_full_blocks
(
...
...
@@ -260,7 +260,7 @@ class BlockPool:
return
True
return
False
def
touch
(
self
,
blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
def
touch
(
self
,
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
...
...
@@ -299,7 +299,7 @@ class BlockPool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
)
num_used_blocks
=
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
"Failed to reset prefix cache because some "
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
646d62f6
...
...
@@ -5,8 +5,7 @@ from typing import Callable, Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SingleTypeKVCacheManager
,
get_manager_for_kv_cache_spec
)
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.request
import
Request
...
...
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
self
.
block_pool
=
BlockPool
(
kv_cache_config
.
num_blocks
,
enable_caching
,
enable_kv_cache_events
)
self
.
single_type_managers
:
list
[
SingleTypeKVCacheManager
]
=
[]
# Needs special handling for find_longest_cache_hit if eagle is enabled
self
.
use_eagle
=
use_eagle
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
i
].
kv_cache_spec
self
.
single_type_managers
.
append
(
self
.
single_type_managers
=
tuple
(
get_manager_for_kv_cache_spec
(
kv_cache_spec
=
kv_cache_spec
,
kv_cache_spec
=
kv_cache_
group
.
kv_cache_
spec
,
block_pool
=
self
.
block_pool
,
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
))
)
for
i
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
))
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
int
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
int
:
"""
Get the number of blocks needed to be allocated for the request.
...
...
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""
Add the new computed blocks to the request.
...
...
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
)
->
list
[
list
[
KVCacheBlock
]]:
num_tokens
:
int
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
...
...
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
Returns:
The new allocated blocks.
"""
new_blocks
=
[]
for
manager
in
self
.
single_type_managers
:
new_blocks
.
append
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
))
return
new_blocks
return
tuple
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
for
manager
in
self
.
single_type_managers
)
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
...
...
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
for
manager
in
self
.
single_type_managers
:
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
def
get_blocks
(
self
,
request_id
:
str
)
->
list
[
list
[
KVCacheBlock
]]:
def
get_blocks
(
self
,
request_id
:
str
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Get the blocks for the request.
"""
return
[
return
tuple
(
manager
.
req_to_blocks
.
get
(
request_id
)
or
[]
for
manager
in
self
.
single_type_managers
]
for
manager
in
self
.
single_type_managers
)
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
pass
...
...
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
hit_blocks
=
self
.
single_type_managers
[
0
].
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
max_length
=
max_cache_hit_length
,
...
...
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now."
)
if
max
(
self
.
full_attention_group_ids
)
<
min
(
self
.
other_group_ids
):
self
.
full_attn_first
=
True
elif
max
(
self
.
other_group_ids
)
<
min
(
self
.
full_attention_group_ids
):
self
.
full_attn_first
=
False
else
:
raise
ValueError
(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks."
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
]
,
...
],
int
]:
"""
Find the longest cache hit for the request.
...
...
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
))
hit_length
=
len
(
hit_blocks_other_attn
[
0
])
*
self
.
other_block_size
# NOTE: the prefix cache hit length must be a multipl
y
of block_size as
# NOTE: the prefix cache hit length must be a multipl
e
of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multipl
y
of the block size of
# of other attention is ensured to be a multipl
e
of the block size of
# full attention layers in current implementation, because hit_length is
# a multipl
y
of other attention's block size, and other attention's
# block size is a multipl
y
of full attention's block size (verified in
# a multipl
e
of other attention's block size, and other attention's
# block size is a multipl
e
of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert
hit_length
%
self
.
full_attention_block_size
==
0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for
i
in
range
(
len
(
hit_blocks_full_attn
)):
del
hit_blocks_full_attn
[
i
][
hit_length
//
self
.
full_attention_block_size
:]
for
group_hit_blocks
in
hit_blocks_full_attn
:
del
group_hit_blocks
[
hit_length
//
self
.
full_attention_block_size
:]
# Merge the hit blocks of full attention and other attention.
hit_blocks
=
hit_blocks_other_attn
for
group_id
,
blocks
in
enumerate
(
hit_blocks_full_attn
):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks
.
insert
(
group_id
,
blocks
)
if
self
.
full_attn_first
:
hit_blocks
=
hit_blocks_full_attn
+
hit_blocks_other_attn
else
:
hit_blocks
=
hit_blocks_other_attn
+
hit_blocks_full_attn
return
hit_blocks
,
hit_length
...
...
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
else
:
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
vllm/v1/core/kv_cache_manager.py
View file @
646d62f6
...
...
@@ -25,7 +25,7 @@ class KVCacheBlocks:
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks
:
list
[
list
[
KVCacheBlock
]]
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
...
...
@@ -37,21 +37,19 @@ class KVCacheBlocks:
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
"""Adds two KVCacheBlocks instances."""
return
KVCacheBlocks
(
[
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)])
tuple
(
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)))
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
)
->
tuple
[
list
[
int
]
,
...
]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list
[list[int]]: A t
wo-level
list where
* the outer
list
corresponds to KV cache groups
tuple
[list[int]
, ...
]: A t
uple of
list
s
where
* the outer
tuple
corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
block_ids
=
[]
for
group
in
self
.
blocks
:
block_ids
.
append
([
blk
.
block_id
for
blk
in
group
])
return
block_ids
return
tuple
([
blk
.
block_id
for
blk
in
group
]
for
group
in
self
.
blocks
)
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...
...
@@ -63,7 +61,7 @@ class KVCacheBlocks:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
(
[
[]
for
_
in
range
(
len
(
self
.
blocks
))
]
)
return
KVCacheBlocks
(
tuple
(
[]
for
_
in
range
(
len
(
self
.
blocks
))
)
)
class
KVCacheManager
:
...
...
@@ -232,9 +230,8 @@ class KVCacheManager:
if
new_computed_blocks
is
not
None
:
new_computed_block_list
=
new_computed_blocks
.
blocks
else
:
new_computed_block_list
=
[
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
))
]
new_computed_block_list
=
tuple
(
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)))
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
...
...
@@ -267,7 +264,7 @@ class KVCacheManager:
if
self
.
enable_caching
:
self
.
block_pool
.
touch
(
new_computed_block_list
)
else
:
assert
all
(
not
blocks
for
blocks
in
new_computed_block_list
),
(
assert
not
any
(
new_computed_block_list
),
(
"Computed blocks should be empty when "
"prefix caching is disabled"
)
...
...
@@ -378,17 +375,18 @@ class KVCacheManager:
"""
return
self
.
block_pool
.
take_events
()
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
tuple
[
list
[
int
]
,
...
]:
"""Get the block ids of a request."""
return
KVCacheBlocks
(
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request."""
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
([[]
for
_
in
range
(
self
.
num_kv_cache_groups
)])
return
KVCacheBlocks
(
tuple
([]
for
_
in
range
(
self
.
num_kv_cache_groups
)))
vllm/v1/core/sched/output.py
View file @
646d62f6
...
...
@@ -27,7 +27,7 @@ class NewRequestData:
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
...
...
@@ -35,7 +35,7 @@ class NewRequestData:
def
from_request
(
cls
,
request
:
Request
,
block_ids
:
list
[
list
[
int
]],
block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
NewRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
@@ -86,7 +86,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
new_token_ids
:
list
[
int
]
new_block_ids
:
list
[
list
[
int
]]
new_block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
@
classmethod
...
...
@@ -95,7 +95,7 @@ class CachedRequestData:
request
:
Request
,
resumed_from_preemption
:
bool
,
new_token_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
CachedRequestData
:
return
cls
(
req_id
=
request
.
request_id
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
646d62f6
...
...
@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
list
[
int
]]]
=
{}
req_to_new_block_ids
:
dict
[
str
,
tuple
[
list
[
int
]
,
...
]]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
token_budget
=
self
.
max_num_scheduled_tokens
# Encoder-related.
...
...
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
# Count the number of pr
i
fix cached tokens.
# Count the number of pr
e
fix cached tokens.
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
# Encoder-related.
...
...
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
resumed_from_preemption
:
bool
,
)
->
CachedRequestData
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...
...
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens
=
min
(
num_computed_tokens
,
request
.
num_tokens
)
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
self
.
kv_cache_manager
.
cache_blocks
(
request
,
self
.
kv_cache_manager
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
,
)
self
.
kv_cache_manager
.
cache_blocks
(
request
,
num_computed_tokens
)
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
646d62f6
...
...
@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
...
...
@@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
]
for block size 4
(
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
)
for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
...
...
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
"FullAttentionManager can only be used for full attention groups"
)
computed_blocks
:
list
[
list
[
KVCacheBlock
]]
=
[
[]
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...]
=
tuple
(
[]
for
_
in
range
(
len
(
kv_cache_group_ids
)))
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
for
i
in
range
(
max_num_blocks
):
block_hash
=
block_hashes
[
i
]
for
i
,
block_hash
in
zip
(
range
(
max_num_blocks
),
block_hashes
):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hash
,
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
computed
_blocks
[
j
]
.
append
(
cached
_block
[
j
]
)
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
.
append
(
cached
)
else
:
break
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
computed
_blocks
[
j
]
.
pop
()
if
use_eagle
and
computed_blocks
[
0
]:
for
computed
in
computed_blocks
:
computed
.
pop
()
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
),
(
"SlidingWindowManager can only be used for sliding window groups"
)
...
...
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
computed_blocks
=
[
[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
computed_blocks
=
tuple
(
[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))
)
num_contiguous_blocks
=
0
match_found
=
False
# Search from right to left and early stop when a match is found.
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hashes
[
i
],
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
computed
_blocks
[
j
]
[
i
]
=
cached
_block
[
j
]
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
[
i
]
=
cached
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
)
:
if
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
del
computed
_blocks
[
j
]
[
i
+
num_contiguous_blocks
:]
for
computed
in
computed_blocks
:
del
computed
[
i
+
num_contiguous_blocks
:]
match_found
=
True
break
else
:
...
...
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if
not
match_found
:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
del
computed
_blocks
[
j
]
[
num_contiguous_blocks
:]
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
computed
_blocks
[
j
]
.
pop
()
for
computed
in
computed_blocks
:
del
computed
[
num_contiguous_blocks
:]
if
use_eagle
and
computed_blocks
[
0
]:
for
computed
in
computed_blocks
:
computed
.
pop
()
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
vllm/v1/worker/block_table.py
View file @
646d62f6
...
...
@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
for
block_size
in
block_sizes
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
append_row
(
self
,
block_ids
:
tuple
[
list
[
int
],
...],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
def
add_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
tuple
[
list
[
int
]
,
...
],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
646d62f6
...
...
@@ -30,7 +30,7 @@ class CachedRequestState:
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
output_token_ids
:
list
[
int
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment