Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32431583
Unverified
Commit
32431583
authored
Feb 07, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 07, 2025
Browse files
[V1] Move KV block hashes from Request to KVCacheManager (#12922)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
b21f0f9d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
31 deletions
+35
-31
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+11
-10
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+23
-8
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+1
-0
vllm/v1/request.py
vllm/v1/request.py
+0
-13
No files found.
tests/v1/core/test_prefix_caching.py
View file @
32431583
...
@@ -51,7 +51,7 @@ def test_prefill():
...
@@ -51,7 +51,7 @@ def test_prefill():
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
]
)
==
3
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
...
@@ -76,7 +76,7 @@ def test_prefill():
...
@@ -76,7 +76,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
...
@@ -107,7 +107,7 @@ def test_prefill():
...
@@ -107,7 +107,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
6
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
]
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
...
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
...
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
assert
not
computed_blocks
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
assert
len
(
req0
.
kv_block_hashes
)
==
3
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
len
(
block_hashes
)
==
3
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
req0
.
kv_block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
...
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
# The just completed block should have hashes with extra keys.
# The just completed block should have hashes with extra keys.
assert
len
(
req0
.
kv_
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
req0
.
kv_
block_hashes
[
3
].
extra_keys
==
(
"ccc"
,
)
assert
block_hashes
[
3
].
extra_keys
==
(
"ccc"
,
)
# Cache hit.
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
...
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
...
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
all_token_ids
)
req1
=
make_request
(
"1"
,
all_token_ids
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
len
(
computed_blocks
)
==
3
assert
len
(
computed_blocks
)
==
3
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
4
]
...
...
vllm/v1/core/kv_cache_manager.py
View file @
32431583
...
@@ -72,6 +72,12 @@ class KVCacheManager:
...
@@ -72,6 +72,12 @@ class KVCacheManager:
self
.
req_to_blocks
:
DefaultDict
[
str
,
self
.
req_to_blocks
:
DefaultDict
[
str
,
List
[
KVCacheBlock
]]
=
defaultdict
(
list
)
List
[
KVCacheBlock
]]
=
defaultdict
(
list
)
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self
.
req_to_block_hashes
:
DefaultDict
[
str
,
List
[
BlockHashType
]]
=
defaultdict
(
list
)
@
property
@
property
def
usage
(
self
)
->
float
:
def
usage
(
self
)
->
float
:
return
1.0
-
(
self
.
free_block_queue
.
num_free_blocks
/
return
1.0
-
(
self
.
free_block_queue
.
num_free_blocks
/
...
@@ -97,11 +103,11 @@ class KVCacheManager:
...
@@ -97,11 +103,11 @@ class KVCacheManager:
computed_blocks
=
[]
computed_blocks
=
[]
# The block hashes for the request may already be computed
# The block hashes for the request may already be computed
# if the
request was preempted and resumed
.
# if the
scheduler has tried to schedule the request before
.
if
not
request
.
kv_block_hashes
:
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
request
.
set_kv_
block_hashes
(
if
not
block_hashes
:
hash_request_tokens
(
self
.
block_size
,
request
)
)
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
)
block_hashes
=
request
.
kv_
block_hashes
self
.
req_to_
block_hashes
[
request
.
request_id
]
=
block_hashes
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash is not
...
@@ -435,7 +441,8 @@ class KVCacheManager:
...
@@ -435,7 +441,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
prev_block: The previous block in the chain.
"""
"""
num_cached_block_hashes
=
len
(
request
.
kv_block_hashes
)
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
num_cached_block_hashes
=
len
(
block_hashes
)
# Update the new blocks with the block hashes through the chain.
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value
=
None
prev_block_hash_value
=
None
...
@@ -468,7 +475,7 @@ class KVCacheManager:
...
@@ -468,7 +475,7 @@ class KVCacheManager:
# this request (either the prompt tokens or the previously
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# generated tokens with preemption). In this case we simply
# reuse the block hash.
# reuse the block hash.
block_hash
=
request
.
kv_
block_hashes
[
blk_idx
]
block_hash
=
block_hashes
[
blk_idx
]
else
:
else
:
# Otherwise compute the block hash and cache it in the request
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
# in case it will be preempted in the future.
...
@@ -490,9 +497,17 @@ class KVCacheManager:
...
@@ -490,9 +497,17 @@ class KVCacheManager:
# Compute the hash of the current block.
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
,
extra_keys
)
block_tokens
,
extra_keys
)
request
.
append_kv_
block_hashes
(
block_hash
)
block_hashes
.
append
(
block_hash
)
# Update and add the full block to the cache.
# Update and add the full block to the cache.
blk
.
block_hash
=
block_hash
blk
.
block_hash
=
block_hash
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
prev_block_hash_value
=
block_hash
.
hash_value
prev_block_hash_value
=
block_hash
.
hash_value
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self
.
req_to_block_hashes
.
pop
(
request
.
request_id
,
None
)
vllm/v1/core/scheduler.py
View file @
32431583
...
@@ -579,6 +579,7 @@ class Scheduler:
...
@@ -579,6 +579,7 @@ class Scheduler:
def
_free_request
(
self
,
request
:
Request
)
->
None
:
def
_free_request
(
self
,
request
:
Request
)
->
None
:
assert
request
.
is_finished
()
assert
request
.
is_finished
()
self
.
kv_cache_manager
.
free
(
request
)
self
.
kv_cache_manager
.
free
(
request
)
self
.
kv_cache_manager
.
free_block_hashes
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
self
.
_cached_reqs_data
.
pop
(
request
.
request_id
,
None
)
self
.
_cached_reqs_data
.
pop
(
request
.
request_id
,
None
)
del
self
.
requests
[
request
.
request_id
]
del
self
.
requests
[
request
.
request_id
]
...
...
vllm/v1/request.py
View file @
32431583
...
@@ -12,7 +12,6 @@ from vllm.v1.utils import ConstantList
...
@@ -12,7 +12,6 @@ from vllm.v1.utils import ConstantList
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
class
Request
:
class
Request
:
...
@@ -63,11 +62,6 @@ class Request:
...
@@ -63,11 +62,6 @@ class Request:
if
self
.
mm_hashes
:
if
self
.
mm_hashes
:
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_hashes
)
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_hashes
)
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self
.
_kv_block_hashes
:
List
[
BlockHashType
]
=
[]
self
.
kv_block_hashes
=
ConstantList
(
self
.
_kv_block_hashes
)
# Read-only views
# Read-only views
# Prevent directly appending to these lists since
# Prevent directly appending to these lists since
# they should also be updated simultaneously.
# they should also be updated simultaneously.
...
@@ -124,13 +118,6 @@ class Request:
...
@@ -124,13 +118,6 @@ class Request:
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
return
num_tokens
return
num_tokens
def
set_kv_block_hashes
(
self
,
value
:
List
[
"BlockHashType"
])
->
None
:
self
.
_kv_block_hashes
=
value
self
.
kv_block_hashes
=
ConstantList
(
self
.
_kv_block_hashes
)
def
append_kv_block_hashes
(
self
,
block_hash
:
"BlockHashType"
)
->
None
:
self
.
_kv_block_hashes
.
append
(
block_hash
)
class
RequestStatus
(
enum
.
IntEnum
):
class
RequestStatus
(
enum
.
IntEnum
):
"""Status of a request."""
"""Status of a request."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment