Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
646d62f6
Unverified
Commit
646d62f6
authored
Jun 09, 2025
by
Nick Hill
Committed by
GitHub
Jun 10, 2025
Browse files
[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
6cd4ae8a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
140 additions
and
142 deletions
+140
-142
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+22
-22
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+4
-4
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+1
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+2
-2
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+4
-4
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+53
-48
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+18
-20
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+4
-4
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-8
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+24
-26
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+3
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
646d62f6
...
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
...
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Check full block metadata
# Check full block metadata
parent_block_hash
=
None
parent_block_hash
=
None
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
...
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
...
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
...
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
6
]
]
assert
blocks
.
get_block_ids
()
==
(
[
6
]
,
)
# Although we only have 6 free blocks, we have 8 blocks in
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
# the free block queue due to lazy removal.
...
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
...
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
]
assert
blocks
.
get_block_ids
()
==
(
[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
...
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
[
9
,
10
,
11
,
12
]
]
8
],
[
9
,
10
,
11
,
12
]
)
# Check full block metadata
# Check full block metadata
parent_block_hash
=
None
parent_block_hash
=
None
...
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
...
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
],
[
0
,
6
,
7
],
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
],
[
0
,
6
,
[
0
,
10
,
11
]
]
7
],
[
0
,
10
,
11
]
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
13
],
[
14
],
[
15
]
]
assert
blocks
.
get_block_ids
()
==
(
[
13
],
[
14
],
[
15
]
)
for
block_per_group
in
computed_blocks
.
blocks
:
for
block_per_group
in
computed_blocks
.
blocks
:
for
block
in
block_per_group
:
for
block
in
block_per_group
:
if
block
!=
manager
.
block_pool
.
null_block
:
if
block
!=
manager
.
block_pool
.
null_block
:
...
@@ -374,7 +374,7 @@ def test_prefill_plp():
...
@@ -374,7 +374,7 @@ def test_prefill_plp():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
# Check full block metadata
# Check full block metadata
...
@@ -400,13 +400,13 @@ def test_prefill_plp():
...
@@ -400,13 +400,13 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
]
,
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
for
block
in
computed_blocks
.
blocks
[
0
]:
for
block
in
computed_blocks
.
blocks
[
0
]:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
...
@@ -444,7 +444,7 @@ def test_prefill_plp():
...
@@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids
=
blocks
.
get_block_ids
()
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
==
req0_block_hashes
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]]
==
req0_block_hashes
assert
block_ids
!=
[
[
1
,
2
,
3
,
4
]
]
assert
block_ids
!=
(
[
1
,
2
,
3
,
4
]
,
)
# Request #2 block hashes are valid since request #0 hashes are.
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
# Check block reference counts.
...
@@ -474,7 +474,7 @@ def test_decode():
...
@@ -474,7 +474,7 @@ def test_decode():
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
# Append slots without allocating a new block.
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
req0
.
num_computed_tokens
=
55
...
@@ -546,12 +546,12 @@ def test_evict():
...
@@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
[
[
1
,
2
]
]
assert
computed_blocks
.
get_block_ids
()
==
(
[
1
,
2
]
,
)
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
10
]
]
assert
blocks
.
get_block_ids
()
==
(
[
10
]
,
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
...
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
# Append slots without allocating a new block.
...
@@ -926,7 +926,7 @@ def test_cache_key_salting():
...
@@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
# Append slots without allocating a new block.
...
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
...
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
[
[
1
,
2
,
3
,
4
]
]
assert
blocks
.
get_block_ids
()
==
(
[
1
,
2
,
3
,
4
]
,
)
unique_token_ids
=
[
4
]
*
7
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
...
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[
[
5
]
]
assert
blocks
.
get_block_ids
()
==
(
[
5
]
,
)
# Failed to reset prefix cache because some blocks are not freed yet.
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
646d62f6
...
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
# block_ids should be
list
[list[int]]
block_ids
=
(
[
0
]
,
)
,
# block_ids should be
tuple
[list[int]]
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
lora_request
=
None
,
lora_request
=
None
,
))
))
...
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
...
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
# This is safe since we currently only use single KV cache groups
block_table
=
multi_group_block_table
[
0
]
block_table
=
multi_group_block_table
[
0
]
# req_state.block_ids is now
list
[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now
tuple
[list[int]
, ...
] for MultiGroupBlockTable
# Extract the first group's block IDs
# Extract the first group's block IDs
if
isinstance
(
req_state
.
block_ids
[
0
],
list
):
if
isinstance
(
req_state
.
block_ids
[
0
],
list
):
# New format:
list
[list[int]] - extract first group
# New format:
tuple
[list[int]
, ...
] - extract first group
req_block_ids
=
req_state
.
block_ids
[
0
]
req_block_ids
=
req_state
.
block_ids
[
0
]
else
:
else
:
# Legacy format: list[int] - use directly
# Legacy format: list[int] - use directly
...
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
req_id
=
req_id
,
resumed_from_preemption
=
False
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
)
)
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
646d62f6
...
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
...
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params
=
_create_sampling_params
(),
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_inputs
=
[],
mm_positions
=
[],
mm_positions
=
[],
block_ids
=
[
[]
]
,
block_ids
=
(
[]
,
)
,
generator
=
None
,
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
646d62f6
...
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
block_ids
=
[
[
0
]
]
,
block_ids
=
(
[
0
]
,
)
,
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
lora_request
=
None
,
lora_request
=
None
,
))
))
...
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
req_id
=
req_id
,
req_id
=
req_id
,
resumed_from_preemption
=
False
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_token_ids
=
[],
new_block_ids
=
[
[]
]
,
new_block_ids
=
(
[]
,
)
,
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
)
)
...
...
vllm/v1/core/block_pool.py
View file @
646d62f6
...
@@ -89,8 +89,8 @@ class BlockPool:
...
@@ -89,8 +89,8 @@ class BlockPool:
BlockHashWithGroupId
(
block_hash
,
group_id
))
BlockHashWithGroupId
(
block_hash
,
group_id
))
if
not
cached_blocks_one_group
:
if
not
cached_blocks_one_group
:
return
None
return
None
first_block
_id
=
next
(
iter
(
cached_blocks_one_group
))
first_block
=
next
(
iter
(
cached_blocks_one_group
.
values
()
))
cached_blocks
.
append
(
cached_blocks_one_group
[
first_block
_id
]
)
cached_blocks
.
append
(
first_block
)
return
cached_blocks
return
cached_blocks
def
cache_full_blocks
(
def
cache_full_blocks
(
...
@@ -260,7 +260,7 @@ class BlockPool:
...
@@ -260,7 +260,7 @@ class BlockPool:
return
True
return
True
return
False
return
False
def
touch
(
self
,
blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
def
touch
(
self
,
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""Touch a block increases its reference count by 1, and may remove
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
another request with the same prefix.
...
@@ -299,7 +299,7 @@ class BlockPool:
...
@@ -299,7 +299,7 @@ class BlockPool:
bool: True if the prefix cache is successfully reset,
bool: True if the prefix cache is successfully reset,
False otherwise.
False otherwise.
"""
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
)
num_used_blocks
=
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
()
if
num_used_blocks
!=
1
:
# The null block is always marked as used
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
logger
.
warning
(
"Failed to reset prefix cache because some "
"Failed to reset prefix cache because some "
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
646d62f6
...
@@ -5,8 +5,7 @@ from typing import Callable, Optional
...
@@ -5,8 +5,7 @@ from typing import Callable, Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SingleTypeKVCacheManager
,
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
...
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
self
.
block_pool
=
BlockPool
(
kv_cache_config
.
num_blocks
,
enable_caching
,
self
.
block_pool
=
BlockPool
(
kv_cache_config
.
num_blocks
,
enable_caching
,
enable_kv_cache_events
)
enable_kv_cache_events
)
self
.
single_type_managers
:
list
[
SingleTypeKVCacheManager
]
=
[]
# Needs special handling for find_longest_cache_hit if eagle is enabled
# Needs special handling for find_longest_cache_hit if eagle is enabled
self
.
use_eagle
=
use_eagle
self
.
use_eagle
=
use_eagle
self
.
single_type_managers
=
tuple
(
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
get_manager_for_kv_cache_spec
(
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
,
i
].
kv_cache_spec
block_pool
=
self
.
block_pool
,
self
.
single_type_managers
.
append
(
kv_cache_group_id
=
i
,
get_manager_for_kv_cache_spec
(
caching_hash_fn
=
caching_hash_fn
,
kv_cache_spec
=
kv_cache_spec
,
)
for
i
,
kv_cache_group
in
enumerate
(
block_pool
=
self
.
block_pool
,
self
.
kv_cache_config
.
kv_cache_groups
))
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
))
def
get_num_blocks_to_allocate
(
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
int
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
int
:
"""
"""
Get the number of blocks needed to be allocated for the request.
Get the number of blocks needed to be allocated for the request.
...
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
...
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
def
save_new_computed_blocks
(
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
self
,
request_id
:
str
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
])
->
None
:
"""
"""
Add the new computed blocks to the request.
Add the new computed blocks to the request.
...
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
...
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks
[
i
])
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
)
->
list
[
list
[
KVCacheBlock
]]:
num_tokens
:
int
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
"""
Allocate new blocks for the request to give it at least `num_tokens`
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
token slots.
...
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
...
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
Returns:
Returns:
The new allocated blocks.
The new allocated blocks.
"""
"""
new_blocks
=
[]
return
tuple
(
for
manager
in
self
.
single_type_managers
:
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
new_blocks
.
append
(
for
manager
in
self
.
single_type_managers
)
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
))
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
num_computed_tokens
:
int
)
->
None
:
...
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
...
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
for
manager
in
self
.
single_type_managers
:
for
manager
in
self
.
single_type_managers
:
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
def
get_blocks
(
self
,
request_id
:
str
)
->
list
[
list
[
KVCacheBlock
]]:
def
get_blocks
(
self
,
request_id
:
str
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
"""
Get the blocks for the request.
Get the blocks for the request.
"""
"""
return
[
return
tuple
(
manager
.
req_to_blocks
.
get
(
request_id
)
or
[]
manager
.
req_to_blocks
.
get
(
request_id
)
or
[]
for
manager
in
self
.
single_type_managers
for
manager
in
self
.
single_type_managers
)
]
@
abstractmethod
@
abstractmethod
def
find_longest_cache_hit
(
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
self
,
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
pass
pass
...
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
...
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
def
find_longest_cache_hit
(
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
self
,
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
],
...],
int
]:
hit_blocks
=
self
.
single_type_managers
[
0
].
find_longest_cache_hit
(
hit_blocks
=
self
.
single_type_managers
[
0
].
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
block_hashes
=
block_hashes
,
max_length
=
max_cache_hit_length
,
max_length
=
max_cache_hit_length
,
...
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"KVCacheCoordinator assumes the block_size of full attention "
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now."
)
"layers is divisible by other layers now."
)
if
max
(
self
.
full_attention_group_ids
)
<
min
(
self
.
other_group_ids
):
self
.
full_attn_first
=
True
elif
max
(
self
.
other_group_ids
)
<
min
(
self
.
full_attention_group_ids
):
self
.
full_attn_first
=
False
else
:
raise
ValueError
(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks."
)
def
find_longest_cache_hit
(
def
find_longest_cache_hit
(
self
,
self
,
block_hashes
:
list
[
BlockHash
],
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
max_cache_hit_length
:
int
,
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
)
->
tuple
[
tuple
[
list
[
KVCacheBlock
]
,
...
],
int
]:
"""
"""
Find the longest cache hit for the request.
Find the longest cache hit for the request.
...
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
))
))
hit_length
=
len
(
hit_blocks_other_attn
[
0
])
*
self
.
other_block_size
hit_length
=
len
(
hit_blocks_other_attn
[
0
])
*
self
.
other_block_size
# NOTE: the prefix cache hit length must be a multipl
y
of block_size as
# NOTE: the prefix cache hit length must be a multipl
e
of block_size as
# we don't support partial block cache hit yet. The cache hit length
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multipl
y
of the block size of
# of other attention is ensured to be a multipl
e
of the block size of
# full attention layers in current implementation, because hit_length is
# full attention layers in current implementation, because hit_length is
# a multipl
y
of other attention's block size, and other attention's
# a multipl
e
of other attention's block size, and other attention's
# block size is a multipl
y
of full attention's block size (verified in
# block size is a multipl
e
of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
# `verify_and_split_kv_cache_groups`).
assert
hit_length
%
self
.
full_attention_block_size
==
0
assert
hit_length
%
self
.
full_attention_block_size
==
0
# Truncate the full attention cache hit to the length of the
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
# cache hit of the other attention.
for
i
in
range
(
len
(
hit_blocks_full_attn
)):
for
group_hit_blocks
in
hit_blocks_full_attn
:
del
hit_blocks_full_attn
[
i
][
hit_length
//
del
group_hit_blocks
[
hit_length
//
self
.
full_attention_block_size
:]
self
.
full_attention_block_size
:]
# Merge the hit blocks of full attention and other attention.
# Merge the hit blocks of full attention and other attention.
hit_blocks
=
hit_blocks_other_attn
if
self
.
full_attn_first
:
for
group_id
,
blocks
in
enumerate
(
hit_blocks_full_attn
):
hit_blocks
=
hit_blocks_full_attn
+
hit_blocks_other_attn
# NOTE: there is only one full attention group in most cases. So
else
:
# the time complexity of insert is fine.
hit_blocks
=
hit_blocks_other_attn
+
hit_blocks_full_attn
hit_blocks
.
insert
(
group_id
,
blocks
)
return
hit_blocks
,
hit_length
return
hit_blocks
,
hit_length
...
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
...
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
use_eagle
,
enable_caching
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
caching_hash_fn
,
enable_kv_cache_events
)
enable_kv_cache_events
)
else
:
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
enable_caching
,
caching_hash_fn
,
use_eagle
,
enable_caching
,
enable_kv_cache_events
)
caching_hash_fn
,
enable_kv_cache_events
)
vllm/v1/core/kv_cache_manager.py
View file @
646d62f6
...
@@ -21,11 +21,11 @@ logger = init_logger(__name__)
...
@@ -21,11 +21,11 @@ logger = init_logger(__name__)
@
dataclass
@
dataclass
class
KVCacheBlocks
:
class
KVCacheBlocks
:
"""
"""
The allocation result of KVCacheManager, work as the interface between
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
structure from the Scheduler.
"""
"""
blocks
:
list
[
list
[
KVCacheBlock
]]
blocks
:
tuple
[
list
[
KVCacheBlock
]
,
...
]
"""
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
We don't use block of tokens as the outer dimension because it assumes all
...
@@ -37,21 +37,19 @@ class KVCacheBlocks:
...
@@ -37,21 +37,19 @@ class KVCacheBlocks:
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
"""Adds two KVCacheBlocks instances."""
"""Adds two KVCacheBlocks instances."""
return
KVCacheBlocks
(
return
KVCacheBlocks
(
[
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)])
tuple
(
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)))
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
)
->
tuple
[
list
[
int
]
,
...
]:
"""
"""
Converts the KVCacheBlocks instance to block_ids.
Converts the KVCacheBlocks instance to block_ids.
Returns:
Returns:
list
[list[int]]: A t
wo-level
list where
tuple
[list[int]
, ...
]: A t
uple of
list
s
where
* the outer
list
corresponds to KV cache groups
* the outer
tuple
corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
* each inner list contains the block_ids of the blocks in that group
"""
"""
block_ids
=
[]
return
tuple
([
blk
.
block_id
for
blk
in
group
]
for
group
in
self
.
blocks
)
for
group
in
self
.
blocks
:
block_ids
.
append
([
blk
.
block_id
for
blk
in
group
])
return
block_ids
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
...
@@ -63,7 +61,7 @@ class KVCacheBlocks:
...
@@ -63,7 +61,7 @@ class KVCacheBlocks:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
def
new_empty
(
self
)
->
"KVCacheBlocks"
:
"""Creates a new KVCacheBlocks instance with no blocks."""
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
(
[
[]
for
_
in
range
(
len
(
self
.
blocks
))
]
)
return
KVCacheBlocks
(
tuple
(
[]
for
_
in
range
(
len
(
self
.
blocks
))
)
)
class
KVCacheManager
:
class
KVCacheManager
:
...
@@ -232,9 +230,8 @@ class KVCacheManager:
...
@@ -232,9 +230,8 @@ class KVCacheManager:
if
new_computed_blocks
is
not
None
:
if
new_computed_blocks
is
not
None
:
new_computed_block_list
=
new_computed_blocks
.
blocks
new_computed_block_list
=
new_computed_blocks
.
blocks
else
:
else
:
new_computed_block_list
=
[
new_computed_block_list
=
tuple
(
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
))
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)))
]
# Free the blocks that are skipped during the attention computation
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# (e.g., tokens outside the sliding window).
...
@@ -267,7 +264,7 @@ class KVCacheManager:
...
@@ -267,7 +264,7 @@ class KVCacheManager:
if
self
.
enable_caching
:
if
self
.
enable_caching
:
self
.
block_pool
.
touch
(
new_computed_block_list
)
self
.
block_pool
.
touch
(
new_computed_block_list
)
else
:
else
:
assert
all
(
not
blocks
for
blocks
in
new_computed_block_list
),
(
assert
not
any
(
new_computed_block_list
),
(
"Computed blocks should be empty when "
"Computed blocks should be empty when "
"prefix caching is disabled"
)
"prefix caching is disabled"
)
...
@@ -378,17 +375,18 @@ class KVCacheManager:
...
@@ -378,17 +375,18 @@ class KVCacheManager:
"""
"""
return
self
.
block_pool
.
take_events
()
return
self
.
block_pool
.
take_events
()
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
tuple
[
list
[
int
]
,
...
]:
"""Get the block ids of a request."""
"""Get the block ids of a request."""
return
KVCacheBlocks
(
return
KVCacheBlocks
(
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request."""
"""Cache the blocks for the request."""
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
num_computed_tokens
)
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
"""Creates a new KVCacheBlocks instance with no blocks."""
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
([[]
for
_
in
range
(
self
.
num_kv_cache_groups
)])
return
KVCacheBlocks
(
tuple
([]
for
_
in
range
(
self
.
num_kv_cache_groups
)))
vllm/v1/core/sched/output.py
View file @
646d62f6
...
@@ -27,7 +27,7 @@ class NewRequestData:
...
@@ -27,7 +27,7 @@ class NewRequestData:
mm_hashes
:
list
[
str
]
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
mm_positions
:
list
[
PlaceholderRange
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
num_computed_tokens
:
int
lora_request
:
Optional
[
LoRARequest
]
lora_request
:
Optional
[
LoRARequest
]
...
@@ -35,7 +35,7 @@ class NewRequestData:
...
@@ -35,7 +35,7 @@ class NewRequestData:
def
from_request
(
def
from_request
(
cls
,
cls
,
request
:
Request
,
request
:
Request
,
block_ids
:
list
[
list
[
int
]],
block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
NewRequestData
:
)
->
NewRequestData
:
return
cls
(
return
cls
(
req_id
=
request
.
request_id
,
req_id
=
request
.
request_id
,
...
@@ -86,7 +86,7 @@ class CachedRequestData:
...
@@ -86,7 +86,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption
:
bool
resumed_from_preemption
:
bool
new_token_ids
:
list
[
int
]
new_token_ids
:
list
[
int
]
new_block_ids
:
list
[
list
[
int
]]
new_block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
num_computed_tokens
:
int
@
classmethod
@
classmethod
...
@@ -95,7 +95,7 @@ class CachedRequestData:
...
@@ -95,7 +95,7 @@ class CachedRequestData:
request
:
Request
,
request
:
Request
,
resumed_from_preemption
:
bool
,
resumed_from_preemption
:
bool
,
new_token_ids
:
list
[
int
],
new_token_ids
:
list
[
int
],
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
)
->
CachedRequestData
:
)
->
CachedRequestData
:
return
cls
(
return
cls
(
req_id
=
request
.
request_id
,
req_id
=
request
.
request_id
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
646d62f6
...
@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
...
@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
# uses structured decoding.
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
structured_output_request_ids
:
dict
[
str
,
int
]
=
{}
req_to_new_block_ids
:
dict
[
str
,
list
[
list
[
int
]]]
=
{}
req_to_new_block_ids
:
dict
[
str
,
tuple
[
list
[
int
]
,
...
]]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
num_scheduled_tokens
:
dict
[
str
,
int
]
=
{}
token_budget
=
self
.
max_num_scheduled_tokens
token_budget
=
self
.
max_num_scheduled_tokens
# Encoder-related.
# Encoder-related.
...
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
...
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
token_budget
-=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
request
.
status
=
RequestStatus
.
RUNNING
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
# Count the number of pr
i
fix cached tokens.
# Count the number of pr
e
fix cached tokens.
if
request
.
num_cached_tokens
<
0
:
if
request
.
num_cached_tokens
<
0
:
request
.
num_cached_tokens
=
num_computed_tokens
request
.
num_cached_tokens
=
num_computed_tokens
# Encoder-related.
# Encoder-related.
...
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
...
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
request
:
Request
,
request
:
Request
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
num_scheduled_spec_tokens
:
int
,
new_block_ids
:
list
[
list
[
int
]],
new_block_ids
:
tuple
[
list
[
int
]
,
...
],
resumed_from_preemption
:
bool
,
resumed_from_preemption
:
bool
,
)
->
CachedRequestData
:
)
->
CachedRequestData
:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
...
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens
=
min
(
num_computed_tokens
,
request
.
num_tokens
)
num_computed_tokens
=
min
(
num_computed_tokens
,
request
.
num_tokens
)
if
num_computed_tokens
==
request
.
num_tokens
:
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
num_computed_tokens
-=
1
self
.
kv_cache_manager
.
cache_blocks
(
self
.
kv_cache_manager
.
cache_blocks
(
request
,
num_computed_tokens
)
request
,
self
.
kv_cache_manager
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
,
)
# Update the request state for scheduling.
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
request
.
num_computed_tokens
=
num_computed_tokens
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
646d62f6
...
@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool
:
BlockPool
,
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
"""
"""
Get the longest cache hit prefix of the blocks that is not longer than
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
`max_length`. The prefix should be a common prefix hit for all the
...
@@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
element is a list of cached blocks for the i-th kv cache group
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
For example, sliding window manager should return a list like
[
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
]
for block size 4
(
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
)
for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
"""
...
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
...
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
"FullAttentionManager can only be used for full attention groups"
)
"FullAttentionManager can only be used for full attention groups"
)
computed_blocks
:
list
[
list
[
KVCacheBlock
]]
=
[
computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...]
=
tuple
(
[]
for
_
in
range
(
len
(
kv_cache_group_ids
))
[]
for
_
in
range
(
len
(
kv_cache_group_ids
)))
]
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
for
i
in
range
(
max_num_blocks
):
for
i
,
block_hash
in
zip
(
range
(
max_num_blocks
),
block_hashes
):
block_hash
=
block_hashes
[
i
]
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
# not computed yet for sure.
if
cached_block
:
=
block_pool
.
get_cached_block
(
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hash
,
kv_cache_group_ids
):
block_hash
,
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
_blocks
[
j
]
.
append
(
cached
_block
[
j
]
)
computed
.
append
(
cached
)
else
:
else
:
break
break
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
if
use_eagle
and
computed_blocks
[
0
]:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
for
computed
in
computed_blocks
:
computed
_blocks
[
j
]
.
pop
()
computed
.
pop
()
return
computed_blocks
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
...
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool
:
BlockPool
,
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
)
->
tuple
[
list
[
KVCacheBlock
]
,
...
]:
assert
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
),
(
assert
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
),
(
"SlidingWindowManager can only be used for sliding window groups"
)
"SlidingWindowManager can only be used for sliding window groups"
)
...
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
...
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# sliding_window_contiguous_blocks),
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
# which is good for low cache hit rate scenarios.
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
computed_blocks
=
[
[
block_pool
.
null_block
]
*
max_num_blocks
computed_blocks
=
tuple
(
[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
for
_
in
range
(
len
(
kv_cache_group_ids
))
)
num_contiguous_blocks
=
0
num_contiguous_blocks
=
0
match_found
=
False
match_found
=
False
# Search from right to left and early stop when a match is found.
# Search from right to left and early stop when a match is found.
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
block_pool
.
get_cached_block
(
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hashes
[
i
],
kv_cache_group_ids
):
block_hashes
[
i
],
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)
):
for
computed
,
cached
in
zip
(
computed_blocks
,
cached_block
):
computed
_blocks
[
j
]
[
i
]
=
cached
_block
[
j
]
computed
[
i
]
=
cached
num_contiguous_blocks
+=
1
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
)
:
if
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
:
# Trim the trailing blocks.
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
# when sliding_window_contiguous_blocks=2.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
for
computed
in
computed_blocks
:
del
computed
_blocks
[
j
]
[
i
+
num_contiguous_blocks
:]
del
computed
[
i
+
num_contiguous_blocks
:]
match_found
=
True
match_found
=
True
break
break
else
:
else
:
...
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
...
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if
not
match_found
:
if
not
match_found
:
# The first `num_contiguous_blocks` is a cache hit even if
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
for
computed
in
computed_blocks
:
del
computed
_blocks
[
j
]
[
num_contiguous_blocks
:]
del
computed
[
num_contiguous_blocks
:]
if
use_eagle
and
len
(
computed_blocks
[
0
]
)
>
0
:
if
use_eagle
and
computed_blocks
[
0
]:
for
j
in
range
(
len
(
kv_cache_group_ids
))
:
for
computed
in
computed_blocks
:
computed
_blocks
[
j
]
.
pop
()
computed
.
pop
()
return
computed_blocks
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
vllm/v1/worker/block_table.py
View file @
646d62f6
...
@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
...
@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
for
block_size
in
block_sizes
for
block_size
in
block_sizes
]
]
def
append_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
append_row
(
self
,
block_ids
:
tuple
[
list
[
int
],
...],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
block_table
.
append_row
(
block_ids
[
i
],
row_idx
)
def
add_row
(
self
,
block_ids
:
list
[
list
[
int
]],
row_idx
:
int
)
->
None
:
def
add_row
(
self
,
block_ids
:
tuple
[
list
[
int
]
,
...
],
row_idx
:
int
)
->
None
:
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
block_table
.
add_row
(
block_ids
[
i
],
row_idx
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
646d62f6
...
@@ -30,7 +30,7 @@ class CachedRequestState:
...
@@ -30,7 +30,7 @@ class CachedRequestState:
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
generator
:
Optional
[
torch
.
Generator
]
generator
:
Optional
[
torch
.
Generator
]
block_ids
:
list
[
list
[
int
]]
block_ids
:
tuple
[
list
[
int
]
,
...
]
num_computed_tokens
:
int
num_computed_tokens
:
int
output_token_ids
:
list
[
int
]
output_token_ids
:
list
[
int
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment