Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f199f19
Unverified
Commit
0f199f19
authored
Jul 18, 2025
by
JialinOuyang-Meta
Committed by
GitHub
Jul 18, 2025
Browse files
[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue (#21005)
Signed-off-by:
Jialin Ouyang
<
jialino@meta.com
>
parent
b2eb2b5a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
210 additions
and
58 deletions
+210
-58
benchmarks/kv_cache/benchmark_block_pool.py
benchmarks/kv_cache/benchmark_block_pool.py
+108
-0
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+15
-13
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+13
-13
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+74
-32
No files found.
benchmarks/kv_cache/benchmark_block_pool.py
0 → 100644
View file @
0f199f19
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
time
from
typing
import
Optional
from
tabulate
import
tabulate
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.v1.core.block_pool
import
BlockPool
class
Metric
:
def
__init__
(
self
)
->
None
:
self
.
cnt
:
int
=
0
self
.
sum_v
:
int
=
0
self
.
max_v
:
Optional
[
int
]
=
None
def
update
(
self
,
v
:
int
)
->
None
:
self
.
cnt
+=
1
self
.
sum_v
+=
v
if
self
.
max_v
is
None
:
self
.
max_v
=
v
else
:
self
.
max_v
=
max
(
self
.
max_v
,
v
)
def
avg_v
(
self
)
->
float
:
return
self
.
sum_v
*
1.0
/
self
.
cnt
def
main
(
args
):
rows
=
[]
for
allocate_block
in
args
.
allocate_blocks
:
# Enforce a GC collect ahead to minimize the impact among runs
gc
.
collect
()
block_pool
=
BlockPool
(
num_gpu_blocks
=
args
.
num_gpu_blocks
,
enable_caching
=
True
)
get_blocks_metric
:
Metric
=
Metric
()
free_blocks_metric
:
Metric
=
Metric
()
for
_
in
range
(
args
.
num_iteration
):
t1
=
time
.
monotonic_ns
()
blocks
=
block_pool
.
get_new_blocks
(
allocate_block
)
t2
=
time
.
monotonic_ns
()
block_pool
.
free_blocks
(
blocks
)
t3
=
time
.
monotonic_ns
()
get_blocks_metric
.
update
(
t2
-
t1
)
free_blocks_metric
.
update
(
t3
-
t2
)
if
get_blocks_metric
.
max_v
is
not
None
and
free_blocks_metric
.
max_v
is
not
None
:
rows
.
append
(
[
get_blocks_metric
.
cnt
,
args
.
num_gpu_blocks
,
allocate_block
,
get_blocks_metric
.
avg_v
()
/
1000000
,
get_blocks_metric
.
max_v
/
1000000.0
,
free_blocks_metric
.
avg_v
()
/
1000000
,
free_blocks_metric
.
max_v
/
1000000.0
,
]
)
else
:
print
(
"No valid metrics found."
f
"
{
get_blocks_metric
.
max_v
=
}
{
free_blocks_metric
.
max_v
=
}
"
)
print
(
tabulate
(
rows
,
headers
=
[
"Iterations"
,
"Total
\n
Blocks"
,
"Allocated
\n
Blocks"
,
"Get Blocks
\n
Avg (ms)"
,
"Get Blocks
\n
Max (ms)"
,
"Free Blocks
\n
Avg (ms)"
,
"Free Blocks
\n
Max (ms)"
,
],
tablefmt
=
"grid"
,
floatfmt
=
".6f"
,
)
)
def
invoke_main
()
->
None
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the performance of BlockPool for KV Cache."
)
parser
.
add_argument
(
"--num-gpu-blocks"
,
type
=
int
,
default
=
100000
)
parser
.
add_argument
(
"--num-iteration"
,
type
=
int
,
default
=
1000
,
help
=
"Number of iterations to run to stablize final data readings"
,
)
parser
.
add_argument
(
"--allocate-blocks"
,
type
=
int
,
nargs
=
"*"
,
default
=
[
10
,
50
,
100
,
500
,
1000
],
help
=
"Number of blocks to allocate"
,
)
args
=
parser
.
parse_args
()
main
(
args
)
if
__name__
==
"__main__"
:
invoke_main
()
# pragma: no cover
tests/v1/core/test_kv_cache_utils.py
View file @
0f199f19
...
@@ -132,8 +132,8 @@ def test_free_kv_cache_block_queue_initialization():
...
@@ -132,8 +132,8 @@ def test_free_kv_cache_block_queue_initialization():
block
=
KVCacheBlock
(
block_id
=
0
)
block
=
KVCacheBlock
(
block_id
=
0
)
queue
=
FreeKVCacheBlockQueue
([
block
])
queue
=
FreeKVCacheBlockQueue
([
block
])
assert
queue
.
num_free_blocks
==
1
assert
queue
.
num_free_blocks
==
1
assert
queue
.
free_list_head
==
block
assert
queue
.
fake_
free_list_head
.
next_free_block
is
block
assert
queue
.
free_list_tail
==
block
assert
queue
.
fake_
free_list_tail
.
prev_free_block
is
block
def
test_free_kv_cache_block_queue_operations
():
def
test_free_kv_cache_block_queue_operations
():
...
@@ -145,36 +145,38 @@ def test_free_kv_cache_block_queue_operations():
...
@@ -145,36 +145,38 @@ def test_free_kv_cache_block_queue_operations():
# Check initial state
# Check initial state
assert
queue
.
num_free_blocks
==
5
assert
queue
.
num_free_blocks
==
5
assert
queue
.
free_list_head
==
blocks
[
0
]
assert
queue
.
fake_
free_list_head
.
next_free_block
is
blocks
[
0
]
assert
queue
.
free_list_tail
==
blocks
[
4
]
assert
queue
.
fake_
free_list_tail
.
prev_free_block
is
blocks
[
4
]
# Pop the first block
# Pop the first block
block1
=
queue
.
popleft
()
block1
=
queue
.
popleft
()
assert
block1
==
blocks
[
0
]
assert
block1
==
blocks
[
0
]
assert
queue
.
num_free_blocks
==
4
assert
queue
.
num_free_blocks
==
4
assert
queue
.
free_list_head
==
blocks
[
1
]
assert
queue
.
fake_
free_list_head
.
next_free_block
is
blocks
[
1
]
assert
queue
.
free_list_tail
==
blocks
[
4
]
assert
queue
.
fake_
free_list_tail
.
prev_free_block
is
blocks
[
4
]
# Remove a block from the middle
# Remove a block from the middle
block_to_remove
=
blocks
[
2
]
block_to_remove
=
blocks
[
2
]
queue
.
remove
(
block_to_remove
)
queue
.
remove
(
block_to_remove
)
assert
queue
.
num_free_blocks
==
3
assert
queue
.
num_free_blocks
==
3
assert
blocks
[
1
].
next_free_block
==
blocks
[
3
]
assert
blocks
[
1
].
next_free_block
is
blocks
[
3
]
assert
blocks
[
3
].
prev_free_block
==
blocks
[
1
]
assert
blocks
[
3
].
prev_free_block
is
blocks
[
1
]
# Append a block back
# Append a block back
queue
.
append
(
block_to_remove
)
queue
.
append
(
block_to_remove
)
assert
queue
.
num_free_blocks
==
4
assert
queue
.
num_free_blocks
==
4
assert
queue
.
free_list_tail
==
block_to_remove
assert
queue
.
fake_
free_list_tail
.
prev_free_block
is
block_to_remove
assert
block_to_remove
.
prev_free_block
==
blocks
[
4
]
assert
block_to_remove
.
prev_free_block
is
blocks
[
4
]
assert
block_to_remove
.
next_free_block
is
None
assert
block_to_remove
.
next_free_block
is
queue
.
fake_free_list_tail
# Pop blocks until empty
# Pop blocks until empty
for
_
in
range
(
4
):
for
_
in
range
(
4
):
queue
.
popleft
()
queue
.
popleft
()
assert
queue
.
num_free_blocks
==
0
assert
queue
.
num_free_blocks
==
0
assert
queue
.
free_list_head
is
None
assert
(
queue
.
fake_free_list_head
.
next_free_block
assert
queue
.
free_list_tail
is
None
is
queue
.
fake_free_list_tail
)
assert
(
queue
.
fake_free_list_tail
.
prev_free_block
is
queue
.
fake_free_list_head
)
# Attempt to pop from an empty queue
# Attempt to pop from an empty queue
with
pytest
.
raises
(
ValueError
)
as
e
:
with
pytest
.
raises
(
ValueError
)
as
e
:
...
...
tests/v1/core/test_prefix_caching.py
View file @
0f199f19
...
@@ -155,13 +155,14 @@ def test_prefill(hash_algo):
...
@@ -155,13 +155,14 @@ def test_prefill(hash_algo):
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
# At this point, we should have 5 free blocks left.
# At this point, we should have 5 free blocks left.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
free_block_queue
=
manager
.
block_pool
.
free_block_queue
assert
free_block_queue
.
num_free_blocks
==
5
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
# All blocks should be available.
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
assert
free_block_queue
.
num_free_blocks
==
10
# The order should be
# The order should be
# [unallocated (6, 7, 8, 9, 10)]
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req0 (4)]
...
@@ -188,14 +189,10 @@ def test_prefill(hash_algo):
...
@@ -188,14 +189,10 @@ def test_prefill(hash_algo):
# Although we only have 6 free blocks, we have 8 blocks in
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
# the free block queue due to lazy removal.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
6
assert
free_block_queue
.
num_free_blocks
==
6
assert
all
([
assert
all
(
b
.
ref_cnt
==
0
[
b
.
ref_cnt
==
0
for
b
in
free_block_queue
.
get_all_free_blocks
()])
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
assert
len
([
b
for
b
in
free_block_queue
.
get_all_free_blocks
()])
==
6
])
assert
len
([
b
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
])
==
6
manager
.
free
(
req2
)
manager
.
free
(
req2
)
...
@@ -209,9 +206,12 @@ def test_prefill(hash_algo):
...
@@ -209,9 +206,12 @@ def test_prefill(hash_algo):
computed_blocks
)
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
([
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
],
)
assert
blocks
.
get_block_ids
()
==
([
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
],
)
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
(
free_block_queue
.
fake_free_list_head
.
next_free_block
is
free_block_queue
.
fake_free_list_tail
)
assert
(
free_block_queue
.
fake_free_list_tail
.
prev_free_block
is
free_block_queue
.
fake_free_list_head
)
def
test_prefill_hybrid_model
():
def
test_prefill_hybrid_model
():
...
...
vllm/v1/core/kv_cache_utils.py
View file @
0f199f19
...
@@ -212,27 +212,65 @@ class FreeKVCacheBlockQueue:
...
@@ -212,27 +212,65 @@ class FreeKVCacheBlockQueue:
def
__init__
(
self
,
blocks
:
list
[
KVCacheBlock
])
->
None
:
def
__init__
(
self
,
blocks
:
list
[
KVCacheBlock
])
->
None
:
self
.
num_free_blocks
=
len
(
blocks
)
self
.
num_free_blocks
=
len
(
blocks
)
# Initialize the doubly linked list of free blocks.
# Initialize doubly links of consecutive blocks
self
.
free_list_head
:
Optional
[
KVCacheBlock
]
=
blocks
[
0
]
self
.
free_list_tail
:
Optional
[
KVCacheBlock
]
=
blocks
[
-
1
]
for
i
in
range
(
self
.
num_free_blocks
):
for
i
in
range
(
self
.
num_free_blocks
):
if
i
>
0
:
if
i
>
0
:
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
if
i
<
self
.
num_free_blocks
-
1
:
if
i
<
self
.
num_free_blocks
-
1
:
blocks
[
i
].
next_free_block
=
blocks
[
i
+
1
]
blocks
[
i
].
next_free_block
=
blocks
[
i
+
1
]
# Create a fake head and a tail block for the doubly linked list to
# reduce branching in the code
#
# The implementation garenteed that the fake head and tail
# are NEVER got popped, so we could safely assume each real blocks
# in the queue has prev and next blocks.
self
.
fake_free_list_head
=
KVCacheBlock
(
block_id
=-
1
)
self
.
fake_free_list_tail
=
KVCacheBlock
(
block_id
=-
1
)
if
self
.
num_free_blocks
>
0
:
# Connect fake_head and fake_tail to the first and last block
# respectively.
self
.
fake_free_list_head
.
next_free_block
=
blocks
[
0
]
blocks
[
0
].
prev_free_block
=
self
.
fake_free_list_head
self
.
fake_free_list_tail
.
prev_free_block
=
blocks
[
-
1
]
blocks
[
-
1
].
next_free_block
=
self
.
fake_free_list_tail
else
:
# For empty list, simply connect the fake head and tail.
self
.
fake_free_list_head
.
next_free_block
=
self
.
fake_free_list_tail
self
.
fake_free_list_tail
.
prev_free_block
=
self
.
fake_free_list_head
def
popleft
(
self
)
->
KVCacheBlock
:
def
popleft
(
self
)
->
KVCacheBlock
:
"""Pop the first free block and reduce num_free_blocks by 1.
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
Returns:
The first free block.
The first free block.
"""
"""
if
not
self
.
free_list_head
:
if
(
self
.
fake_free_list_head
.
next_free_block
is
self
.
fake_free_list_tail
or
self
.
fake_free_list_head
.
next_free_block
is
None
):
assert
self
.
num_free_blocks
==
0
,
(
f
"num_free_blocks (
{
self
.
num_free_blocks
}
) is out of sync "
"with the free list."
)
raise
ValueError
(
"No free blocks available"
)
raise
ValueError
(
"No free blocks available"
)
block
=
self
.
free_list_head
first_block
:
KVCacheBlock
=
self
.
fake_free_list_head
.
next_free_block
self
.
remove
(
block
)
return
block
if
first_block
.
next_free_block
is
None
:
# This should not happen if the block is from the free list.
# It indicates a bug in the caller's logic.
raise
RuntimeError
(
"Invalid block found in popleft() "
"which doesn't have a valid next_free_block"
)
# Connect fake_head and the next block of first_block (i.e. second block
# or fake tail).
self
.
fake_free_list_head
.
next_free_block
=
first_block
.
next_free_block
first_block
.
next_free_block
.
prev_free_block
=
self
.
fake_free_list_head
# Remove the block from the linked list.
first_block
.
prev_free_block
=
first_block
.
next_free_block
=
None
self
.
num_free_blocks
-=
1
return
first_block
def
remove
(
self
,
block
:
KVCacheBlock
)
->
None
:
def
remove
(
self
,
block
:
KVCacheBlock
)
->
None
:
"""Remove a block in the free list and reduce num_free_blocks by 1.
"""Remove a block in the free list and reduce num_free_blocks by 1.
...
@@ -240,19 +278,15 @@ class FreeKVCacheBlockQueue:
...
@@ -240,19 +278,15 @@ class FreeKVCacheBlockQueue:
Args:
Args:
block: The block to remove.
block: The block to remove.
"""
"""
if
block
.
prev_free_block
is
not
None
:
if
block
.
prev_free_block
is
None
or
block
.
next_free_block
is
None
:
# Link the previous block to the next block.
# This should not happen if the block is from the free list.
block
.
prev_free_block
.
next_free_block
=
block
.
next_free_block
# It indicates a bug in the caller's logic.
if
block
.
next_free_block
is
not
None
:
raise
RuntimeError
(
f
"remove() called on an invalid block:
{
block
}
"
)
# Link the next block to the previous block.
block
.
next_free_block
.
prev_free_block
=
block
.
prev_free_block
# Link the previous block to the next block.
block
.
prev_free_block
.
next_free_block
=
block
.
next_free_block
if
block
==
self
.
free_list_head
:
# Link the next block to the previous block.
# Update the head if the block is the head.
block
.
next_free_block
.
prev_free_block
=
block
.
prev_free_block
self
.
free_list_head
=
block
.
next_free_block
if
block
==
self
.
free_list_tail
:
# Update the tail if the block is the tail.
self
.
free_list_tail
=
block
.
prev_free_block
# Remove the block from the linked list.
# Remove the block from the linked list.
block
.
prev_free_block
=
block
.
next_free_block
=
None
block
.
prev_free_block
=
block
.
next_free_block
=
None
...
@@ -265,17 +299,19 @@ class FreeKVCacheBlockQueue:
...
@@ -265,17 +299,19 @@ class FreeKVCacheBlockQueue:
Args:
Args:
block: The block to append.
block: The block to append.
"""
"""
if
self
.
free_list_tail
is
not
None
:
if
self
.
fake_free_list_tail
.
prev_free_block
is
None
:
# Link the last block to the new block.
raise
RuntimeError
(
self
.
free_list_tail
.
next_free_block
=
block
"prev_free_block of fake_free_list_tail should always exist"
)
block
.
prev_free_block
=
self
.
free_list_tail
last_block
:
KVCacheBlock
=
self
.
fake_free_list_tail
.
prev_free_block
self
.
free_list_tail
=
block
else
:
# Connect the new block after the last block.
# The free list is empty.
last_block
.
next_free_block
=
block
assert
self
.
free_list_head
is
None
block
.
prev_free_block
=
last_block
self
.
free_list_head
=
self
.
free_list_tail
=
block
# Connect the fake tail after the new block.
block
.
next_free_block
=
self
.
fake_free_list_tail
self
.
fake_free_list_tail
.
prev_free_block
=
block
block
.
next_free_block
=
None
self
.
num_free_blocks
+=
1
self
.
num_free_blocks
+=
1
def
get_all_free_blocks
(
self
)
->
list
[
KVCacheBlock
]:
def
get_all_free_blocks
(
self
)
->
list
[
KVCacheBlock
]:
...
@@ -285,8 +321,14 @@ class FreeKVCacheBlockQueue:
...
@@ -285,8 +321,14 @@ class FreeKVCacheBlockQueue:
A list of free blocks.
A list of free blocks.
"""
"""
ret
=
[]
ret
=
[]
curr_block
=
self
.
free_list_head
if
self
.
fake_free_list_head
.
next_free_block
is
None
:
while
curr_block
is
not
None
:
raise
RuntimeError
(
"next_free_block of fake_free_list_head should always exist"
)
# Start from the first block
curr_block
:
KVCacheBlock
=
self
.
fake_free_list_head
.
next_free_block
# As long as next_free_block is available, we haven't reached to
# the fake tail yet.
while
curr_block
.
next_free_block
is
not
None
:
ret
.
append
(
curr_block
)
ret
.
append
(
curr_block
)
curr_block
=
curr_block
.
next_free_block
curr_block
=
curr_block
.
next_free_block
return
ret
return
ret
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment