Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
994fc655
Unverified
Commit
994fc655
authored
Jan 15, 2025
by
Chen Zhang
Committed by
GitHub
Jan 15, 2025
Browse files
[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (#12003)
parent
3f9b7ab9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
35 deletions
+61
-35
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+47
-24
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+12
-5
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-6
No files found.
tests/v1/core/test_prefix_caching.py
View file @
994fc655
...
...
@@ -49,9 +49,10 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
...
@@ -73,9 +74,10 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
...
...
@@ -91,7 +93,7 @@ def test_prefill():
# All blocks should be available.
assert
manager
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# [unallocated (7, 8)]
# [unallocated (7, 8
, 9
)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
...
...
@@ -103,9 +105,10 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
]
...
...
@@ -123,8 +126,9 @@ def test_prefill():
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
9
))
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
9
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
9
,
4
,
3
,
6
,
5
,
8
,
7
,
2
,
1
,
0
]
...
...
@@ -150,8 +154,9 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
...
@@ -197,16 +202,18 @@ def test_evict():
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
assert
len
(
blocks
)
==
7
# 5 full + 1 partial + 1 preallocated
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
computed_blocks
)
assert
len
(
blocks
)
==
3
# 3 full blocks
last_token_id
+=
3
*
16
...
...
@@ -222,8 +229,9 @@ def test_evict():
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
]
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
5
]
assert
manager
.
free_block_queue
.
num_free_blocks
==
6
...
...
@@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
...
...
@@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
-
1
,
computed_blocks
)
assert
len
(
blocks
)
==
1
...
...
@@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
0
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
1
...
...
@@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
)
==
1
assert
computed_blocks
[
0
].
block_id
==
0
assert
num_computed_tokens
==
block_size
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
computed_blocks
)
...
...
@@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
10
,
computed_blocks
)
assert
len
(
blocks
)
==
3
...
...
@@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
16
,
computed_blocks
)
assert
len
(
blocks
)
==
4
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
4
,
computed_blocks
)
assert
not
blocks
...
...
@@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
req
=
make_request
(
"0"
,
list
(
range
(
block_size
*
30
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
# Just ask for 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
req
.
num_computed_tokens
=
block_size
...
...
@@ -469,10 +486,11 @@ def test_mm_prefix_caching():
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
...
...
@@ -503,8 +521,9 @@ def test_mm_prefix_caching():
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
computed_blocks
)
==
3
assert
num_computed_tokens
==
3
*
16
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
...
...
@@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req0
,
48
,
computed_blocks
)
block_part0
=
manager
.
req_to_blocks
[
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
==
block_part0
assert
num_computed_tokens
==
3
*
16
manager
.
allocate_slots
(
req1
,
48
,
computed_blocks
)
block_part1
=
manager
.
req_to_blocks
[
req1
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
...
...
@@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
computed_blocks
)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
...
...
@@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
==
block_part1
assert
num_computed_tokens
==
6
*
16
# Req3 cannot be allocated.
assert
manager
.
allocate_slots
(
req3
,
48
,
computed_blocks
)
is
None
# Block 0-2 are used by Req 1.
...
...
vllm/v1/core/kv_cache_manager.py
View file @
994fc655
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
...
...
@@ -69,7 +69,8 @@ class KVCacheManager:
# is finished.
self
.
req_to_blocks
:
Dict
[
str
,
List
[
KVCacheBlock
]]
=
{}
def
get_computed_blocks
(
self
,
request
:
Request
)
->
List
[
KVCacheBlock
]:
def
get_computed_blocks
(
self
,
request
:
Request
)
->
Tuple
[
List
[
KVCacheBlock
],
int
]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
...
...
@@ -77,11 +78,13 @@ class KVCacheManager:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
if
not
self
.
enable_caching
:
# Prefix caching is disabled.
return
[]
return
[]
,
0
computed_blocks
=
[]
...
...
@@ -101,7 +104,11 @@ class KVCacheManager:
else
:
break
return
computed_blocks
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
def
append_slots
(
self
,
...
...
vllm/v1/core/scheduler.py
View file @
994fc655
...
...
@@ -184,12 +184,8 @@ class Scheduler:
request
=
self
.
waiting
[
0
]
# Get already-cached tokens.
computed_blocks
=
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
computed_blocks
,
num_computed_tokens
=
\
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment