Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
994fc655
Unverified
Commit
994fc655
authored
Jan 15, 2025
by
Chen Zhang
Committed by
GitHub
Jan 15, 2025
Browse files
[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (#12003)
parent
3f9b7ab9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
35 deletions
+61
-35
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+47
-24
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-5
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-6
No files found.
tests/v1/core/test_prefix_caching.py
View file @
994fc655
...
@@ -49,9 +49,10 @@ def test_prefill():
...
@@ -49,9 +49,10 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
@@ -73,9 +74,10 @@ def test_prefill():
...
@@ -73,9 +74,10 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
...
@@ -91,7 +93,7 @@ def test_prefill():
...
@@ -91,7 +93,7 @@ def test_prefill():
# All blocks should be available.
# All blocks should be available.
assert
manager
.
free_block_queue
.
num_free_blocks
==
10
assert
manager
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# The order should be
# [unallocated (7, 8)]
# [unallocated (7, 8
, 9
)]
# [unique_req0 (4, 3)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [common (2, 1, 0)]
...
@@ -103,9 +105,10 @@ def test_prefill():
...
@@ -103,9 +105,10 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
# Incomplete 1 block (6 tokens)
unique_token_ids
=
[
3
]
*
6
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
]
...
@@ -123,8 +126,9 @@ def test_prefill():
...
@@ -123,8 +126,9 @@ def test_prefill():
# Cache miss and eviction.
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
9
))
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
9
))
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
9
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
9
,
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
9
,
4
,
3
,
6
,
5
,
8
,
7
,
2
,
1
,
0
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
9
,
4
,
3
,
6
,
5
,
8
,
7
,
2
,
1
,
0
]
...
@@ -150,8 +154,9 @@ def test_decode():
...
@@ -150,8 +154,9 @@ def test_decode():
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
@@ -197,16 +202,18 @@ def test_evict():
...
@@ -197,16 +202,18 @@ def test_evict():
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
assert
len
(
blocks
)
==
7
# 5 full + 1 partial + 1 preallocated
assert
len
(
blocks
)
==
7
# 5 full + 1 partial + 1 preallocated
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)))
last_token_id
+
3
*
16
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
computed_blocks
)
assert
len
(
blocks
)
==
3
# 3 full blocks
assert
len
(
blocks
)
==
3
# 3 full blocks
last_token_id
+=
3
*
16
last_token_id
+=
3
*
16
...
@@ -222,8 +229,9 @@ def test_evict():
...
@@ -222,8 +229,9 @@ def test_evict():
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
]
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
5
]
assert
manager
.
free_block_queue
.
num_free_blocks
==
6
assert
manager
.
free_block_queue
.
num_free_blocks
==
6
...
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
...
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
len
(
blocks
)
==
1
...
@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
...
@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
-
1
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
-
1
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
len
(
blocks
)
==
1
...
@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
...
@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
0
assert
blocks
[
0
].
block_id
==
0
# Allocate another block.
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
1
assert
blocks
[
0
].
block_id
==
1
...
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
...
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
)
==
1
assert
len
(
computed_blocks
)
==
1
assert
computed_blocks
[
0
].
block_id
==
0
assert
computed_blocks
[
0
].
block_id
==
0
assert
num_computed_tokens
==
block_size
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
computed_blocks
)
computed_blocks
)
...
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
...
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
10
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
10
,
computed_blocks
)
assert
len
(
blocks
)
==
3
assert
len
(
blocks
)
==
3
...
@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
...
@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
# No caching.
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
16
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
16
,
computed_blocks
)
assert
len
(
blocks
)
==
4
assert
len
(
blocks
)
==
4
# New requests should not have any blocks.
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
4
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req3
,
4
,
computed_blocks
)
assert
not
blocks
assert
not
blocks
...
@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
...
@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
num_preallocated_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
req
=
make_request
(
"0"
,
list
(
range
(
block_size
*
30
)))
req
=
make_request
(
"0"
,
list
(
range
(
block_size
*
30
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
# Just ask for 1 block.
# Just ask for 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
req
.
num_computed_tokens
=
block_size
req
.
num_computed_tokens
=
block_size
...
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
...
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
all_token_ids
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
...
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
...
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
all_token_ids
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
computed_blocks
)
==
3
assert
len
(
computed_blocks
)
==
3
assert
num_computed_tokens
==
3
*
16
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
...
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... |
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req0
,
48
,
computed_blocks
)
manager
.
allocate_slots
(
req0
,
48
,
computed_blocks
)
block_part0
=
manager
.
req_to_blocks
[
req0
.
request_id
]
block_part0
=
manager
.
req_to_blocks
[
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
==
block_part0
assert
computed_blocks
==
block_part0
assert
num_computed_tokens
==
3
*
16
manager
.
allocate_slots
(
req1
,
48
,
computed_blocks
)
manager
.
allocate_slots
(
req1
,
48
,
computed_blocks
)
block_part1
=
manager
.
req_to_blocks
[
req1
.
request_id
]
block_part1
=
manager
.
req_to_blocks
[
req1
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
...
@@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
computed_blocks
)
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
computed_blocks
)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
...
@@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# In this case, the ref_cnt of the computed blocks should not be changed.
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
==
block_part1
assert
computed_blocks
==
block_part1
assert
num_computed_tokens
==
6
*
16
# Req3 cannot be allocated.
# Req3 cannot be allocated.
assert
manager
.
allocate_slots
(
req3
,
48
,
computed_blocks
)
is
None
assert
manager
.
allocate_slots
(
req3
,
48
,
computed_blocks
)
is
None
# Block 0-2 are used by Req 1.
# Block 0-2 are used by Req 1.
...
...
vllm/v1/core/kv_cache_manager.py
View file @
994fc655
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -69,7 +69,8 @@ class KVCacheManager:
...
@@ -69,7 +69,8 @@ class KVCacheManager:
# is finished.
# is finished.
self
.
req_to_blocks
:
Dict
[
str
,
List
[
KVCacheBlock
]]
=
{}
self
.
req_to_blocks
:
Dict
[
str
,
List
[
KVCacheBlock
]]
=
{}
def
get_computed_blocks
(
self
,
request
:
Request
)
->
List
[
KVCacheBlock
]:
def
get_computed_blocks
(
self
,
request
:
Request
)
->
Tuple
[
List
[
KVCacheBlock
],
int
]:
"""Get the computed (cached) blocks for the request.
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Note that the computed blocks must be full.
...
@@ -77,11 +78,13 @@ class KVCacheManager:
...
@@ -77,11 +78,13 @@ class KVCacheManager:
request: The request to get the computed blocks.
request: The request to get the computed blocks.
Returns:
Returns:
A list of blocks that are computed for the request.
A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
"""
if
not
self
.
enable_caching
:
if
not
self
.
enable_caching
:
# Prefix caching is disabled.
# Prefix caching is disabled.
return
[]
return
[]
,
0
computed_blocks
=
[]
computed_blocks
=
[]
...
@@ -101,7 +104,11 @@ class KVCacheManager:
...
@@ -101,7 +104,11 @@ class KVCacheManager:
else
:
else
:
break
break
return
computed_blocks
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
def
append_slots
(
def
append_slots
(
self
,
self
,
...
...
vllm/v1/core/scheduler.py
View file @
994fc655
...
@@ -184,12 +184,8 @@ class Scheduler:
...
@@ -184,12 +184,8 @@ class Scheduler:
request
=
self
.
waiting
[
0
]
request
=
self
.
waiting
[
0
]
# Get already-cached tokens.
# Get already-cached tokens.
computed_blocks
=
self
.
kv_cache_manager
.
get_computed_blocks
(
computed_blocks
,
num_computed_tokens
=
\
request
)
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
# Number of tokens to be scheduled.
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# `request.num_prompt_tokens` to consider the resumed requests,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment