Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
201fc077
Unverified
Commit
201fc077
authored
Nov 07, 2024
by
Cody Yu
Committed by
GitHub
Nov 07, 2024
Browse files
[V1] Prefix caching (take 2) (#9972)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
42b4f46b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
771 additions
and
66 deletions
+771
-66
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+1
-8
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+219
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+335
-47
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+194
-0
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+21
-11
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-0
No files found.
benchmarks/benchmark_prefix_caching.py
View file @
201fc077
...
...
@@ -118,7 +118,7 @@ def main(args):
random
.
seed
(
args
.
seed
)
if
args
.
dataset_path
is
not
None
:
print
(
f
"Start to sample
{
args
.
num_prompts
}
prompts"
"from {args.dataset_path}"
)
f
"from
{
args
.
dataset_path
}
"
)
filtered_datasets
=
sample_requests
(
dataset_path
=
args
.
dataset_path
,
num_requests
=
args
.
num_prompts
,
...
...
@@ -142,13 +142,6 @@ def main(args):
repeat_count
=
args
.
repeat_count
,
sort
=
args
.
sort
)
print
(
"------warm up------"
)
test_prefix
(
llm
=
llm
,
prompts
=
prompts
,
sampling_params
=
sampling_params
,
)
print
(
"------start generating------"
)
test_prefix
(
llm
=
llm
,
...
...
tests/v1/core/test_prefix_caching.py
0 → 100644
View file @
201fc077
"""Compare the with and without prefix caching."""
from
vllm.inputs
import
DecoderOnlyInputs
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
hash_block_tokens
def
make_request
(
request_id
,
prompt_token_ids
):
return
Request
(
request_id
=
request_id
,
inputs
=
DecoderOnlyInputs
(
prompt_token_ids
=
prompt_token_ids
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
)
def
test_prefill
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
sliding_window
=
False
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
# Check full block metadata
parent_block_hash
=
None
for
block_id
in
(
0
,
1
,
2
):
block_hash
=
hash_block_tokens
(
parent_block_hash
,
manager
.
block_pool
[
block_id
].
token_ids
)
assert
manager
.
block_pool
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
[
block_id
].
num_hashed_tokens
==
16
*
(
block_id
+
1
)
assert
manager
.
block_pool
[
block_id
].
token_ids
==
tuple
([
block_id
]
*
16
)
parent_block_hash
=
block_hash
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
assert
manager
.
block_pool
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
[
block_id
].
num_hashed_tokens
==
0
if
block_id
==
3
:
assert
manager
.
block_pool
[
block_id
].
token_ids
==
[
3
]
*
7
else
:
assert
not
manager
.
block_pool
[
block_id
].
token_ids
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
# At this point, we should have 3 free blocks left.
assert
manager
.
free_block_queue
.
num_free_blocks
==
3
manager
.
free
(
req0
)
manager
.
free
(
req1
)
# All blocks should be available.
assert
manager
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# [unallocated (7, 8)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert
[
b
.
block_id
for
b
in
manager
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
7
,
8
,
9
,
4
,
3
,
6
,
5
,
2
,
1
,
0
]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_block
=
manager
.
get_computed_blocks
(
req2
)
assert
[
b
.
block_id
for
b
in
computed_block
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
]
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert
manager
.
free_block_queue
.
num_free_blocks
==
5
assert
all
([
b
.
ref_cnt
==
0
for
b
in
manager
.
free_block_queue
.
get_all_free_blocks
()
])
assert
len
([
b
for
b
in
manager
.
free_block_queue
.
get_all_free_blocks
()])
==
5
manager
.
free
(
req2
)
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
9
))
computed_blocks
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req2
,
16
*
9
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
9
,
4
,
3
,
6
,
5
,
8
,
7
,
2
,
1
,
0
]
assert
manager
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
free_block_queue
.
free_list_tail
is
None
def
test_decode
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
sliding_window
=
False
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
for
_
in
range
(
4
):
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
append_slots
(
req0
,
4
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
len
(
manager
.
block_pool
[
3
].
token_ids
)
==
11
# Append slots without allocating a new block, but start using the
# preallocated block.
req0
.
num_computed_tokens
=
59
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for
_
in
range
(
5
+
10
):
req0
.
append_output_token_ids
(
7
)
new_blocks
=
manager
.
append_slots
(
req0
,
15
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
assert
len
(
manager
.
block_pool
[
3
].
token_ids
)
==
16
assert
len
(
manager
.
block_pool
[
4
].
token_ids
)
==
10
# Append slots with allocating a new block.
req0
.
num_computed_tokens
=
74
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for
_
in
range
(
6
+
11
):
req0
.
append_output_token_ids
(
12
)
new_blocks
=
manager
.
append_slots
(
req0
,
17
)
# Plus one preallocated block.
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
2
assert
len
(
manager
.
block_pool
[
4
].
token_ids
)
==
16
assert
len
(
manager
.
block_pool
[
5
].
token_ids
)
==
11
assert
len
(
manager
.
block_pool
[
6
].
token_ids
)
==
0
def
test_evict
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
sliding_window
=
False
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
computed_blocks
)
assert
len
(
blocks
)
==
7
# 5 full + 1 partial + 1 preallocated
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
computed_blocks
)
assert
len
(
blocks
)
==
3
# 3 full blocks
last_token_id
+=
3
*
16
assert
manager
.
free_block_queue
.
num_free_blocks
==
0
manager
.
free
(
req0
)
manager
.
free
(
req1
)
assert
manager
.
free_block_queue
.
num_free_blocks
==
10
assert
[
b
.
block_id
for
b
in
manager
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
6
,
5
,
4
,
3
,
2
,
1
,
0
,
9
,
8
,
7
]
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
=
manager
.
get_computed_blocks
(
req2
)
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
]
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
5
]
assert
manager
.
free_block_queue
.
num_free_blocks
==
6
vllm/v1/core/kv_cache_manager.py
View file @
201fc077
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
import
numpy
as
np
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
hash_block_tokens
,
hash_request_tokens
)
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -36,73 +38,359 @@ class KVCacheManager:
self
.
num_preallocate_tokens
=
num_preallocate_tokens
self
.
num_preallocate_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
self
.
free_block_ids
=
list
(
range
(
num_gpu_blocks
))
self
.
req_to_block_ids
:
Dict
[
str
,
List
[
int
]]
=
{}
self
.
ref_cnts
=
np
.
zeros
(
num_gpu_blocks
,
dtype
=
np
.
int32
)
# A Block pool of all kv-cache blocks.
self
.
block_pool
:
List
[
KVCacheBlock
]
=
[
KVCacheBlock
(
idx
)
for
idx
in
range
(
num_gpu_blocks
)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self
.
free_block_queue
=
FreeKVCacheBlockQueue
(
self
.
block_pool
)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self
.
cached_block_hash_to_block
:
Dict
[
BlockHashType
,
Dict
[
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self
.
req_to_blocks
:
Dict
[
str
,
List
[
KVCacheBlock
]]
=
{}
def
get_computed_blocks
(
self
,
request
:
Request
)
->
List
[
int
]:
def
get_computed_blocks
(
self
,
request
:
Request
)
->
List
[
KVCacheBlock
]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
"""
if
not
self
.
enable_caching
:
#
No p
refix caching.
#
P
refix caching
is disabled
.
return
[]
# TODO(woosuk): Implement hash-based caching.
return
[]
computed_blocks
=
[]
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
.
all_token_ids
)
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if
cached_block
:
=
self
.
_get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
else
:
break
return
computed_blocks
def
append_slots
(
self
,
request
:
Request
,
num_tokens
:
int
,
)
->
Optional
[
List
[
int
]]:
)
->
Optional
[
List
[
KVCacheBlock
]]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
"""
num_required_blocks
=
cdiv
(
request
.
num_computed_tokens
+
num_tokens
,
self
.
block_size
)
req_block_ids
=
self
.
req_to_block_ids
[
request
.
request_id
]
if
num_required_blocks
<=
len
(
req_block_ids
):
# No new block is needed.
return
[]
req_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
num_new_blocks
=
num_required_blocks
-
len
(
req_block
_id
s
)
num_
free
_blocks
=
len
(
self
.
free_block_
ids
)
if
num_
new
_
blocks
>
num_free_blocks
:
#
C
annot allocate new blocks.
num_new_blocks
=
num_required_blocks
-
len
(
req_blocks
)
if
num_
new
_blocks
>
self
.
free_block_
queue
.
num_free_blocks
:
# Need to allocate
new
blocks
due to insufficient pre-allocated
#
slots, but we c
annot allocate new blocks
due to the limit
.
return
None
# Allocate new blocks.
# When caching is enabled, assign token IDs to already allocated blocks.
new_token_ids
=
None
parent_block
=
None
if
self
.
enable_caching
:
# Figure out the token IDs to add to the blocks.
new_token_ids
=
request
.
all_token_ids
[
request
.
num_computed_tokens
:
request
.
num_computed_tokens
+
num_tokens
]
# Find the last full block index.
# TODO: This may be optimized by calculating the computed tokens.
last_full_block_idx
=
len
(
req_blocks
)
-
1
while
(
last_full_block_idx
>=
0
and
req_blocks
[
last_full_block_idx
].
block_hash
is
None
):
last_full_block_idx
-=
1
parent_block
=
(
req_blocks
[
last_full_block_idx
]
if
last_full_block_idx
>=
0
else
None
)
token_id_idx
=
self
.
_add_token_ids_to_blocks
(
blocks
=
req_blocks
[
last_full_block_idx
+
1
:],
token_ids
=
new_token_ids
,
parent_block
=
parent_block
)
new_token_ids
=
new_token_ids
[
token_id_idx
:]
parent_block
=
req_blocks
[
-
1
]
# No new block is needed. When caching is enabled, we make sure
# token_id_idx is equal to len(new_token_ids), meaning that all tokens
# are added to allocated blocks.
if
num_required_blocks
<=
len
(
req_blocks
):
assert
not
self
.
enable_caching
or
token_id_idx
==
num_tokens
,
\
f
"
{
token_id_idx
=
}
!=
{
num_tokens
=
}
"
return
[]
# Allocate new blocks considering preallocated blocks, and
# add token IDs to them if caching is enabled.
num_new_blocks
=
min
(
num_new_blocks
+
self
.
num_preallocate_blocks
,
num_free_blocks
)
new_block
_id
s
=
self
.
_get_new_blocks
(
num_new_blocks
)
req_block_ids
.
extend
(
new
_block
_ids
)
self
.
ref_cnts
[
new_block_ids
]
+=
1
return
new_block
_id
s
self
.
free_block_queue
.
num_free_blocks
)
new_blocks
=
self
.
_get_new_blocks
(
num_new_blocks
,
new_token_ids
,
parent
_block
)
req_blocks
.
extend
(
new_blocks
)
return
new_blocks
def
allocate_slots
(
self
,
request
:
Request
,
num_tokens
:
int
,
computed_block_ids
:
List
[
int
],
)
->
Optional
[
List
[
int
]]:
computed_blocks
:
List
[
KVCacheBlock
],
)
->
Optional
[
List
[
KVCacheBlock
]]:
"""Allocate slots for a new request.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: The blocks that have already been computed.
Returns:
A list of new allocated blocks.
"""
if
num_tokens
==
0
:
raise
ValueError
(
f
"num_tokens must be greater than 0, got
{
num_tokens
}
"
)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks
=
len
(
[
blk
for
blk
in
computed_blocks
if
blk
.
ref_cnt
==
0
])
num_required_blocks
=
cdiv
(
num_tokens
,
self
.
block_size
)
num_
f
re
e
_blocks
=
len
(
self
.
free_block_
ids
)
if
num_required_blocks
>
num_free
_blocks
:
if
(
num_re
quired
_blocks
>
self
.
free_block_
queue
.
num_free_blocks
-
num_evictable_computed
_blocks
)
:
# Cannot allocate new blocks.
return
None
num_new_blocks
=
min
(
num_required_blocks
+
self
.
num_preallocate_blocks
,
num_free_blocks
)
new_block_ids
=
self
.
_get_new_blocks
(
num_new_blocks
)
block_ids
=
computed_block_ids
+
new_block_ids
self
.
req_to_block_ids
[
request
.
request_id
]
=
block_ids
self
.
ref_cnts
[
block_ids
]
+=
1
return
new_block_ids
# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks
=
min
(
num_required_blocks
+
self
.
num_preallocate_blocks
,
self
.
free_block_queue
.
num_free_blocks
-
num_evictable_computed_blocks
)
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids
=
None
parent_block
=
None
if
self
.
enable_caching
:
# Touch the computed blocks to make sure they won't be evicted.
self
.
_touch
(
computed_blocks
)
# Get the token IDs for the blocks being allocated for hashing.
new_token_ids
=
request
.
all_token_ids
[
num_computed_tokens
:
num_computed_tokens
+
num_tokens
]
if
not
new_token_ids
:
raise
RuntimeError
(
"Failed to infer the token IDs for allocation. "
f
"#all_tokens=
{
len
(
request
.
all_token_ids
)
}
< "
f
"#computed_tokens=
{
num_computed_tokens
}
"
)
# Get the parent block ID to construct the block chain.
parent_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
new_blocks
=
self
.
_get_new_blocks
(
num_new_blocks
,
new_token_ids
,
parent_block
)
# Concatenate the computed block IDs and the new block IDs.
self
.
req_to_blocks
[
request
.
request_id
]
=
computed_blocks
+
new_blocks
return
new_blocks
def
free
(
self
,
request
:
Request
)
->
None
:
block_ids
=
self
.
req_to_block_ids
.
pop
(
request
.
request_id
)
self
.
ref_cnts
[
block_ids
]
-=
1
for
block_id
in
block_ids
:
ref_cnt
=
self
.
ref_cnts
[
block_id
]
if
ref_cnt
==
0
:
self
.
free_block_ids
.
append
(
block_id
)
def
_get_new_blocks
(
self
,
num_blocks
:
int
)
->
List
[
int
]:
assert
num_blocks
<=
len
(
self
.
free_block_ids
)
new_block_ids
=
self
.
free_block_ids
[
-
num_blocks
:]
self
.
free_block_ids
=
self
.
free_block_ids
[:
-
num_blocks
]
return
new_block_ids
"""Free the blocks allocated for the request.
When caching is enabled, we free the blocks in reverse order so that
the tail blocks are evicted first.
Args:
request: The request to free the blocks.
"""
blocks
=
self
.
req_to_blocks
.
pop
(
request
.
request_id
)
if
self
.
enable_caching
:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks
=
reversed
(
blocks
)
for
block
in
blocks
:
block
.
ref_cnt
-=
1
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
def
_get_new_blocks
(
self
,
num_blocks
:
int
,
token_ids
:
Optional
[
List
[
int
]]
=
None
,
parent_block
:
Optional
[
int
]
=
None
)
->
List
[
KVCacheBlock
]:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns:
A list of new block.
"""
if
num_blocks
>
self
.
free_block_queue
.
num_free_blocks
:
raise
ValueError
(
f
"Cannot get
{
num_blocks
}
free blocks from the pool"
)
# First allocate blocks.
ret
:
List
[
KVCacheBlock
]
=
[]
idx
=
0
while
idx
<
num_blocks
:
curr_block
=
self
.
free_block_queue
.
popleft
()
assert
curr_block
.
ref_cnt
==
0
# Evict blocks from the cache.
if
self
.
enable_caching
:
block_hash
=
curr_block
.
block_hash
if
(
block_hash
is
not
None
and
block_hash
in
self
.
cached_block_hash_to_block
):
if
len
(
self
.
cached_block_hash_to_block
[
block_hash
])
==
1
:
del
self
.
cached_block_hash_to_block
[
block_hash
]
else
:
del
self
.
cached_block_hash_to_block
[
block_hash
][
curr_block
.
block_id
]
curr_block
.
reset
()
curr_block
.
ref_cnt
=
1
ret
.
append
(
curr_block
)
idx
+=
1
# Then assign token IDs to the allocated blocks.
if
self
.
enable_caching
:
assert
token_ids
is
not
None
token_id_idx
=
self
.
_add_token_ids_to_blocks
(
blocks
=
ret
,
token_ids
=
token_ids
,
parent_block
=
parent_block
)
assert
token_id_idx
==
len
(
token_ids
)
return
ret
def
_cache_full_block
(
self
,
block
:
KVCacheBlock
,
parent_block
:
Optional
[
KVCacheBlock
]
=
None
)
->
None
:
"""Cache a full block for prefix caching.
Args:
block: The block to cache.
parent_block: The parent block. None if this is the first block.
"""
parent_block_hash
=
(
parent_block
.
block_hash
if
parent_block
is
not
None
else
None
)
assert
len
(
block
.
token_ids
)
==
self
.
block_size
block
.
token_ids
=
tuple
(
block
.
token_ids
)
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block
.
token_ids
)
block
.
block_hash
=
block_hash
block
.
num_hashed_tokens
=
self
.
block_size
+
(
parent_block
.
num_hashed_tokens
if
parent_block
is
not
None
else
0
)
self
.
cached_block_hash_to_block
[
block_hash
][
block
.
block_id
]
=
block
def
_get_cached_block
(
self
,
block_hash
:
BlockHashType
)
->
Optional
[
KVCacheBlock
]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
Returns:
The cached block if it exists, or None.
"""
if
block_hash
in
self
.
cached_block_hash_to_block
:
first_block_id
=
list
(
self
.
cached_block_hash_to_block
[
block_hash
].
keys
())[
0
]
return
self
.
cached_block_hash_to_block
[
block_hash
][
first_block_id
]
return
None
def
_touch
(
self
,
blocks
:
List
[
KVCacheBlock
])
->
None
:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for
block
in
blocks
:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
remove
(
block
)
block
.
ref_cnt
+=
1
def
_add_token_ids_to_blocks
(
self
,
blocks
:
List
[
KVCacheBlock
],
token_ids
:
List
[
int
],
parent_block
:
Optional
[
KVCacheBlock
]
=
None
)
->
int
:
"""Add token IDs to a list of allocated blocks.
If a block becomes full after adding token IDs, cache it.
Return the token ID index that has not been added to the blocks
if the blocks are not enough to hold all the token IDs.
Args:
blocks: A list of blocks to add token IDs.
token_ids: A list of token IDs to add.
parent_block: The parent block. None if this is the
first block.
Returns:
The starting token ID index that has not been added to the blocks
due to insufficient given blocks.
"""
token_id_start
=
0
for
curr_block
in
blocks
:
# If all token IDs are added, then the rest of the blocks are
# preallocated blocks, so we only need to update the
# parent_block_id. FIXME
if
token_id_start
==
len
(
token_ids
):
continue
# Add token IDs to the empty slots in the block.
empty_slots
=
self
.
block_size
-
len
(
curr_block
.
token_ids
)
token_id_end
=
min
(
token_id_start
+
empty_slots
,
len
(
token_ids
))
curr_block
.
token_ids
.
extend
(
token_ids
[
token_id_start
:
token_id_end
])
# Cache the block if it becomes full.
if
len
(
curr_block
.
token_ids
)
==
self
.
block_size
:
self
.
_cache_full_block
(
curr_block
,
parent_block
)
parent_block
=
curr_block
token_id_start
=
token_id_end
return
token_id_start
vllm/v1/core/kv_cache_utils.py
0 → 100644
View file @
201fc077
"""KV-Cache Utilities."""
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
BlockHashType
=
Tuple
[
int
,
Tuple
[
int
]]
@
dataclass
class
KVCacheBlock
:
"""KV-cache block metadata."""
# Block ID, ranging from 0 to num_gpu_blocks - 1.
block_id
:
int
# Reference count.
ref_cnt
:
int
=
0
# Token IDs in the block. When the block is full, the type of token_ids
# should be Tuple[int] for fast matching.
token_ids
:
Union
[
List
[
int
],
Tuple
[
int
]]
=
field
(
default_factory
=
list
)
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
block_hash
:
Optional
[
BlockHashType
]
=
None
# The number of hashed tokens. More hashed tokens means the block
# is closer to the end of a prompt and more likely to be evicted.
num_hashed_tokens
:
int
=
0
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block
:
Optional
[
"KVCacheBlock"
]
=
None
next_free_block
:
Optional
[
"KVCacheBlock"
]
=
None
def
reset
(
self
):
"""Reset the block metadata."""
self
.
ref_cnt
=
0
self
.
token_ids
=
[]
self
.
block_hash
=
None
self
.
num_hashed_tokens
=
0
class
FreeKVCacheBlockQueue
:
"""This class organizes a list of KVCacheBlock objects to a doubly linked
list of free blocks. We implement this class instead of using Python
builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated
and then freed, it will be appended back with the eviction order:
1. The least recent used block is at the front (LRU).
2. If two blocks have the same last accessed time (allocated by the
same sequence), the one with more hash tokens (the tail of a block
chain) is at the front.
Note that we maintain this order by reversing the block order when free
blocks of a request. This operation is outside of this class.
Args:
blocks: A list of KVCacheBlock objects.
"""
def
__init__
(
self
,
blocks
:
List
[
KVCacheBlock
])
->
None
:
self
.
num_free_blocks
=
len
(
blocks
)
# Initialize the doubly linked list of free blocks.
self
.
free_list_head
=
blocks
[
0
]
self
.
free_list_tail
=
blocks
[
-
1
]
for
i
in
range
(
self
.
num_free_blocks
):
if
i
>
0
:
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
if
i
<
self
.
num_free_blocks
-
1
:
blocks
[
i
].
next_free_block
=
blocks
[
i
+
1
]
def
popleft
(
self
)
->
KVCacheBlock
:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
if
not
self
.
free_list_head
:
raise
ValueError
(
"No free blocks available"
)
block
=
self
.
free_list_head
self
.
remove
(
block
)
return
block
def
remove
(
self
,
block
:
KVCacheBlock
)
->
None
:
"""Remove a block in the free list and reduce num_free_blocks by 1.
Args:
block: The block to remove.
"""
if
block
.
prev_free_block
is
not
None
:
# Link the previous block to the next block.
block
.
prev_free_block
.
next_free_block
=
block
.
next_free_block
if
block
.
next_free_block
is
not
None
:
# Link the next block to the previous block.
block
.
next_free_block
.
prev_free_block
=
block
.
prev_free_block
if
block
==
self
.
free_list_head
:
# Update the head if the block is the head.
self
.
free_list_head
=
block
.
next_free_block
if
block
==
self
.
free_list_tail
:
# Update the tail if the block is the tail.
self
.
free_list_tail
=
block
.
prev_free_block
# Remove the block from the linked list.
block
.
prev_free_block
=
block
.
next_free_block
=
None
self
.
num_free_blocks
-=
1
def
append
(
self
,
block
:
KVCacheBlock
)
->
None
:
"""Put a block back into the free list and increase
num_free_blocks by 1.
Args:
block: The block to append.
"""
if
self
.
free_list_tail
is
not
None
:
# Link the last block to the new block.
self
.
free_list_tail
.
next_free_block
=
block
block
.
prev_free_block
=
self
.
free_list_tail
self
.
free_list_tail
=
block
else
:
# The free list is empty.
assert
self
.
free_list_head
is
None
self
.
free_list_head
=
self
.
free_list_tail
=
block
block
.
next_free_block
=
None
self
.
num_free_blocks
+=
1
def
get_all_free_blocks
(
self
)
->
List
[
KVCacheBlock
]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
A list of free blocks.
"""
ret
=
[]
curr_block
=
self
.
free_list_head
while
curr_block
is
not
None
:
ret
.
append
(
curr_block
)
curr_block
=
curr_block
.
next_free_block
return
ret
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Tuple
[
int
])
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
TODO: Support arbitrary metadata so that we could support more
features such as LoRA adapter.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
block. The current block is assumed to be full.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return
(
hash
(
(
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
List
[
int
])
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
Returns:
The list of computed hash values.
"""
ret
=
[]
parent_block_hash
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
end
=
start
+
block_size
block_token_ids
=
tuple
(
token_ids
[
start
:
end
])
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
break
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_token_ids
)
ret
.
append
(
block_hash
)
parent_block_hash
=
block_hash
return
ret
vllm/v1/core/scheduler.py
View file @
201fc077
...
...
@@ -34,7 +34,7 @@ class Scheduler:
block_size
=
self
.
cache_config
.
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
enable_caching
=
True
)
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
)
self
.
block_size
=
self
.
cache_config
.
block_size
# Scheduling constraints.
...
...
@@ -91,9 +91,9 @@ class Scheduler:
assert
num_new_tokens
>
0
while
True
:
new_block
_id
s
=
self
.
kv_cache_manager
.
append_slots
(
new_blocks
=
self
.
kv_cache_manager
.
append_slots
(
request
,
num_new_tokens
)
if
new_block
_id
s
is
None
:
if
new_blocks
is
None
:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req
=
self
.
running
.
pop
()
...
...
@@ -110,7 +110,9 @@ class Scheduler:
# The request can be scheduled.
scheduled_running_reqs
.
append
(
request
)
req_to_new_block_ids
[
request
.
request_id
]
=
new_block_ids
req_to_new_block_ids
[
request
.
request_id
]
=
[
b
.
block_id
for
b
in
new_blocks
]
num_scheduled_tokens
[
request
.
request_id
]
=
num_new_tokens
token_budget
-=
num_new_tokens
req_index
+=
1
...
...
@@ -126,22 +128,29 @@ class Scheduler:
request
=
self
.
waiting
[
0
]
# Get already-cached tokens.
computed_block
_id
s
=
self
.
kv_cache_manager
.
get_computed_blocks
(
computed_blocks
=
self
.
kv_cache_manager
.
get_computed_blocks
(
request
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_block
_id
s
)
*
self
.
block_size
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
if
num_new_tokens
==
0
:
# The happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last token.
num_computed_tokens
-=
1
num_new_tokens
=
1
computed_blocks
.
pop
()
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
new_block
_id
s
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
computed_block
_id
s
)
if
new_block
_id
s
is
None
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
computed_blocks
)
if
new_blocks
is
None
:
# The request cannot be scheduled.
break
request
.
num_computed_tokens
=
num_computed_tokens
...
...
@@ -156,8 +165,9 @@ class Scheduler:
raise
RuntimeError
(
f
"Invalid request status:
{
request
.
status
}
"
)
req_to_new_block_ids
[
request
.
request_id
]
=
(
computed_block_ids
+
new_block_ids
)
req_to_new_block_ids
[
request
.
request_id
]
=
[
b
.
block_id
for
b
in
computed_blocks
+
new_blocks
]
num_scheduled_tokens
[
request
.
request_id
]
=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
...
...
vllm/v1/engine/llm_engine.py
View file @
201fc077
...
...
@@ -65,6 +65,7 @@ class LLMEngine:
elif
usage_context
==
UsageContext
.
OPENAI_API_SERVER
:
scheduler_config
.
max_num_seqs
=
1024
scheduler_config
.
max_num_batched_tokens
=
2048
cache_config
.
enable_prefix_caching
=
True
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment