Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a5f0afc
Unverified
Commit
3a5f0afc
authored
Apr 01, 2025
by
Chen Zhang
Committed by
GitHub
Apr 01, 2025
Browse files
[V1] Implement sliding window attention in kv_cache_manager (#14097)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
c7e63aa4
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
662 additions
and
158 deletions
+662
-158
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+11
-4
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+63
-70
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+12
-0
tests/v1/core/test_specialized_manager.py
tests/v1/core/test_specialized_manager.py
+138
-0
tests/v1/e2e/test_correctness_sliding_window.py
tests/v1/e2e/test_correctness_sliding_window.py
+84
-0
vllm/config.py
vllm/config.py
+1
-2
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+11
-4
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+50
-21
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+32
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+5
-17
vllm/v1/core/specialized_manager.py
vllm/v1/core/specialized_manager.py
+161
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+11
-7
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+46
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+19
-9
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+18
-10
No files found.
tests/core/block/e2e/test_correctness_sliding_window.py
View file @
3a5f0afc
...
...
@@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
check_answers
(
indices
,
answer
,
test_texts
)
def
prep_prompts
(
batch_size
:
int
):
def
prep_prompts
(
batch_size
:
int
,
ln_range
:
tuple
[
int
,
int
]
=
(
800
,
1100
)
):
"""
Generate prompts with a bunch of assignments,
then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.
Args:
batch_size: number of prompts to generate
ln_range: an argument to control the length of the prompt
"""
prompts
:
list
[
str
]
=
[]
answer
:
list
[
int
]
=
[]
...
...
@@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
indices
.
append
(
idx
)
prompt
=
"```python
\n
# We set a number of variables, "
+
\
f
"x
{
idx
}
will be important later
\n
"
ln
=
random
.
randint
(
800
,
1100
)
ln
=
random
.
randint
(
*
ln_range
)
for
k
in
range
(
30
,
ln
):
v
=
random
.
randint
(
10
,
99
)
if
k
==
idx
:
...
...
@@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
return
prompts
,
answer
,
indices
def
check_answers
(
indices
:
list
[
int
],
answer
:
list
[
int
],
outputs
:
list
[
str
]):
def
check_answers
(
indices
:
list
[
int
],
answer
:
list
[
int
],
outputs
:
list
[
str
],
accept_rate
:
float
=
0.7
):
answer2
=
[
int
(
text
[
0
:
2
].
strip
())
for
text
in
outputs
]
print
(
list
(
zip
(
indices
,
zip
(
answer
,
answer2
))))
numok
=
0
...
...
@@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
numok
+=
1
frac_ok
=
numok
/
len
(
answer
)
print
(
f
"Num OK:
{
numok
}
/
{
len
(
answer
)
}
{
frac_ok
}
"
)
assert
frac_ok
>
0.7
assert
frac_ok
>
=
accept_rate
def
check_window
(
prompts
:
list
[
str
]):
...
...
tests/v1/core/test_prefix_caching.py
View file @
3a5f0afc
...
...
@@ -4,6 +4,7 @@
from
typing
import
Optional
import
pytest
import
torch
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
KVCacheBlock
,
hash_block_tokens
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
)
def
make_request
(
request_id
,
...
...
@@ -39,13 +42,23 @@ def make_request(request_id,
)
def
make_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
tensors
=
{},
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
))
],
)
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"hash"
])
def
test_prefill
(
hash_algo
):
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
caching_hash_algo
=
hash_algo
,
num_preallocate_tokens
=
16
,
...
...
@@ -67,12 +80,12 @@ def test_prefill(hash_algo):
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
# Check full block metadata
parent_block_hash
=
None
for
block_id
in
(
0
,
1
,
2
):
block_tokens
=
tuple
(
all_token_ids
[
block_id
*
16
:
(
block_id
+
1
)
*
16
])
for
block_id
in
(
1
,
2
,
3
):
block_tokens
=
tuple
(
all_token_ids
[
(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
==
block_hash
...
...
@@ -80,7 +93,7 @@ def test_prefill(hash_algo):
parent_block_hash
=
block_hash
.
hash_value
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
for
block_id
in
(
4
,
5
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
...
...
@@ -90,11 +103,11 @@ def test_prefill(hash_algo):
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
,
3
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
7
]
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
...
...
@@ -107,14 +120,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# [unallocated (
7,
8, 9)]
# [unique_req0 (
4
,
3
)]
# [unique_req1 (
6
,
5
)]
# [common (2, 1
, 0
)]
# [unallocated (8, 9
, 10
)]
# [unique_req0 (
5
,
4
)]
# [unique_req1 (
7
,
6
)]
# [common (
3,
2, 1)]
assert
[
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
7
,
8
,
9
,
4
,
3
,
6
,
5
,
2
,
1
,
0
]
]
==
[
8
,
9
,
10
,
5
,
4
,
7
,
6
,
3
,
2
,
1
]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
...
...
@@ -122,11 +135,11 @@ def test_prefill(hash_algo):
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
,
3
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
8
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
8
,
9
]
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
...
...
@@ -148,7 +161,7 @@ def test_prefill(hash_algo):
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
9
,
computed_blocks
)
# This block ID order also checks the eviction order.
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
9
,
4
,
3
,
6
,
5
,
8
,
7
,
2
,
1
,
0
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
10
,
5
,
4
,
7
,
6
,
9
,
8
,
3
,
2
,
1
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
0
assert
manager
.
block_pool
.
free_block_queue
.
free_list_head
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
...
...
@@ -162,10 +175,8 @@ def test_prefill_plp():
3. Schedule plp request; no hit should occur; validate blocks
'''
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
...
...
@@ -186,13 +197,13 @@ def test_prefill_plp():
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
]
# Check full block metadata
parent_block_hash
=
None
for
block_id
in
(
0
,
1
,
2
):
block_tokens
=
tuple
(
all_token_ids
[
block_id
*
16
:
(
block_id
+
1
)
*
16
])
for
block_id
in
(
1
,
2
,
3
):
block_tokens
=
tuple
(
all_token_ids
[
(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
==
block_hash
...
...
@@ -200,7 +211,7 @@ def test_prefill_plp():
parent_block_hash
=
block_hash
.
hash_value
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
for
block_id
in
(
4
,
5
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
...
...
@@ -211,11 +222,11 @@ def test_prefill_plp():
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
,
3
]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
,
6
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
7
]
for
block
in
computed_blocks
:
assert
block
.
ref_cnt
==
2
...
...
@@ -228,14 +239,14 @@ def test_prefill_plp():
# All blocks should be available.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
10
# The order should be
# [unallocated (
7,
8, 9)]
# [unique_req0 (
4
,
3
)]
# [unique_req1 (
6
,
5
)]
# [common (2, 1
, 0
)]
# [unallocated (8, 9
, 10
)]
# [unique_req0 (
5
,
4
)]
# [unique_req1 (
7
,
6
)]
# [common (
3,
2, 1)]
assert
[
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
7
,
8
,
9
,
4
,
3
,
6
,
5
,
2
,
1
,
0
]
]
==
[
8
,
9
,
10
,
5
,
4
,
7
,
6
,
3
,
2
,
1
]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
...
...
@@ -251,7 +262,7 @@ def test_prefill_plp():
block_ids
=
[
b
.
block_id
for
b
in
blocks
]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
]
==
req0_block_hashes
assert
block_ids
!=
[
0
,
1
,
2
,
3
,
4
]
assert
block_ids
!=
[
1
,
2
,
3
,
4
,
5
]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
...
...
@@ -263,10 +274,8 @@ def test_prefill_plp():
def
test_decode
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
...
...
@@ -282,7 +291,7 @@ def test_decode():
assert
not
computed_blocks
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
# Append slots without allocating a new block.
req0
.
num_computed_tokens
=
55
...
...
@@ -316,10 +325,8 @@ def test_decode():
def
test_evict
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
...
...
@@ -350,15 +357,15 @@ def test_evict():
assert
[
b
.
block_id
for
b
in
manager
.
block_pool
.
free_block_queue
.
get_all_free_blocks
()
]
==
[
6
,
5
,
4
,
3
,
2
,
1
,
0
,
9
,
8
,
7
]
]
==
[
7
,
6
,
5
,
4
,
3
,
2
,
1
,
10
,
9
,
8
]
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
1
,
2
]
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
6
,
5
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
7
,
6
]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
6
...
...
@@ -369,10 +376,8 @@ def test_hash_block_correct_reuse():
"""
block_size
=
16
manager
=
KVCacheManager
(
block_size
=
block_size
,
num_gpu_blocks
=
1
,
make_kv_cache_config
(
16
,
2
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
...
...
@@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted():
"""
block_size
=
16
manager
=
KVCacheManager
(
block_size
=
block_size
,
num_gpu_blocks
=
2
,
make_kv_cache_config
(
block_size
,
3
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
...
...
@@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted():
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
0
assert
blocks
[
0
].
block_id
==
1
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
...
...
@@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted():
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
1
assert
blocks
[
0
].
block_id
==
2
# Free the blocks.
manager
.
free
(
req0
)
...
...
@@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted():
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
)
==
1
assert
computed_blocks
[
0
].
block_id
==
0
assert
computed_blocks
[
0
].
block_id
==
1
assert
num_computed_tokens
==
block_size
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
computed_blocks
)
assert
len
(
blocks
)
==
1
assert
blocks
[
0
].
block_id
==
1
assert
blocks
[
0
].
block_id
==
2
def
test_basic_prefix_caching_disabled
():
...
...
@@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled():
"""
block_size
=
4
manager
=
KVCacheManager
(
block_size
=
block_size
,
num_gpu_blocks
=
4
,
make_kv_cache_config
(
block_size
,
5
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
False
,
num_preallocate_tokens
=
0
,
)
...
...
@@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
This tests that the preallocated blocks are correctly added.
"""
manager
=
KVCacheManager
(
block_size
=
block_size
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
num_preallocate_tokens
,
)
...
...
@@ -586,10 +585,8 @@ def test_mm_prefix_caching():
This tests that the multi-modal prefix caching is correct.
"""
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
...
...
@@ -629,7 +626,7 @@ def test_mm_prefix_caching():
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
,
5
]
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
...
...
@@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
block_size
=
16
manager
=
KVCacheManager
(
block_size
=
block_size
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
...
...
@@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def
test_reset_prefix_cache
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
make_kv_cache_config
(
16
,
11
),
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
...
...
@@ -736,7 +729,7 @@ def test_reset_prefix_cache():
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
1
,
2
,
3
,
4
]
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
...
...
@@ -745,7 +738,7 @@ def test_reset_prefix_cache():
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
computed_blocks
)
==
3
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
5
]
# Failed to reset prefix cache because some blocks are not freed yet.
assert
not
manager
.
reset_prefix_cache
()
...
...
tests/v1/core/test_scheduler.py
View file @
3a5f0afc
...
...
@@ -2,12 +2,15 @@
from
typing
import
Optional
import
pytest
import
torch
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.structured_output
import
StructuredOutputManager
...
...
@@ -66,12 +69,21 @@ def create_scheduler(
model_config
=
model_config
,
cache_config
=
cache_config
,
)
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
10000
,
# A large number of blocks to hold all requests
tensors
=
{},
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
FullAttentionSpec
(
16
,
1
,
1
,
torch
.
float32
,
False
))
],
)
cache_config
.
num_gpu_blocks
=
10000
return
Scheduler
(
scheduler_config
,
model_config
,
cache_config
,
lora_config
=
None
,
kv_cache_config
=
kv_cache_config
,
log_stats
=
True
,
structured_output_manager
=
StructuredOutputManager
(
vllm_config
),
)
...
...
tests/v1/core/test_specialized_manager.py
0 → 100644
View file @
3a5f0afc
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
,
KVCacheBlock
from
vllm.v1.core.specialized_manager
import
SlidingWindowManager
from
vllm.v1.kv_cache_interface
import
SlidingWindowSpec
def
test_sliding_window_possible_cached_prefix
():
sliding_window_spec
=
SlidingWindowSpec
(
block_size
=
2
,
num_kv_heads
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
sliding_window
=
4
,
use_mla
=
False
,
)
block_pool
=
BlockPool
(
num_gpu_blocks
=
100
,
enable_caching
=
True
)
manager
=
SlidingWindowManager
(
sliding_window_spec
,
block_pool
)
def
run_one_case
(
block_is_cached
,
expect_length
):
block_hash_list
=
[
BlockHashType
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
]
block_pool
.
cached_block_hash_to_block
.
clear
()
# Mock the block pool with the cached blocks
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
block_hash
]
=
{
i
:
block_pool
.
blocks
[
i
+
10
]
}
computed_blocks
=
manager
.
find_longest_cache_hit
(
block_hash_list
)
assert
len
(
computed_blocks
)
==
expect_length
assert
all
(
block
==
block_pool
.
null_block
for
block
in
computed_blocks
[:
expect_length
-
2
])
for
i
in
range
(
2
):
if
i
<
expect_length
:
block_index
=
expect_length
-
i
-
1
assert
computed_blocks
[
block_index
].
block_id
==
block_index
+
10
run_one_case
([
False
]
*
10
,
0
)
run_one_case
([
True
],
1
)
run_one_case
([
True
,
False
],
1
)
run_one_case
([
True
,
True
],
2
)
run_one_case
([
True
,
True
,
False
],
2
)
run_one_case
([
True
,
True
,
True
],
3
)
run_one_case
([
True
,
True
,
True
,
False
],
3
)
run_one_case
([
True
,
True
,
False
,
True
,
False
,
False
,
True
,
True
,
False
,
True
,
True
,
True
],
12
)
run_one_case
([
True
,
True
,
False
,
True
,
False
,
False
,
True
,
True
,
False
,
False
,
False
],
8
)
run_one_case
([
True
,
True
,
False
,
True
,
False
,
False
,
True
,
True
,
False
,
False
,
False
,
True
],
8
)
def
test_sliding_window_remove_skipped_blocks
():
sliding_window_spec
=
SlidingWindowSpec
(
block_size
=
2
,
num_kv_heads
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
sliding_window
=
4
,
use_mla
=
False
,
)
block_pool
=
BlockPool
(
num_gpu_blocks
=
2000
,
enable_caching
=
True
)
manager
=
SlidingWindowManager
(
sliding_window_spec
,
block_pool
)
null_block_id
=
block_pool
.
null_block
.
block_id
def
id_to_block_table
(
ids
):
return
[
KVCacheBlock
(
id_
)
if
id_
!=
null_block_id
else
block_pool
.
null_block
for
id_
in
ids
]
def
assert_block_id
(
block_table
,
ids
):
for
block
,
id_
in
zip
(
block_table
,
ids
):
if
id_
==
null_block_id
:
assert
block
==
block_pool
.
null_block
else
:
assert
block
.
block_id
==
id_
original_block_ids
=
[
1000
,
1001
,
1002
,
1003
,
1004
,
1005
,
1006
,
1007
,
1008
,
1009
,
1010
]
block_table
=
id_to_block_table
(
original_block_ids
)
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
0
)
assert_block_id
(
removed
,
[])
assert_block_id
(
block_table
,
original_block_ids
)
# 4 tokens are computed. Only token 0 is out of the sliding window. As
# block 1000 also contains token 1 that is in the sliding window, block 1000
# cannot be removed.
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
4
)
assert_block_id
(
removed
,
[])
assert_block_id
(
block_table
,
original_block_ids
)
# 5 tokens are computed. Token 0 & 1 are out of the sliding window.
# Block 1000 can be removed.
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
5
)
assert_block_id
(
removed
,
[
original_block_ids
[
0
]])
assert_block_id
(
block_table
,
[
null_block_id
]
+
original_block_ids
[
1
:])
# 6 tokens are computed. Token 0-2 are out of the sliding window.
# Cannot remove new block as the block 1001 is still used by token 3.
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
6
)
assert_block_id
(
removed
,
[])
assert_block_id
(
block_table
,
[
null_block_id
]
+
original_block_ids
[
1
:])
# 7 tokens are computed. Token 0-3 are out of the sliding window.
# Block 1001 can be removed and block 1000 is already removed.
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
7
)
assert_block_id
(
removed
,
[
original_block_ids
[
1
]])
assert_block_id
(
block_table
,
[
null_block_id
]
*
2
+
original_block_ids
[
2
:])
# 11 tokens are computed. Token 0-7 are out of the sliding window.
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
# sequence, and is expected to be evicted earlier than 1002, so the order
# of removed blocks should be [1003, 1002].
removed
=
manager
.
remove_skipped_blocks
(
block_table
,
11
)
assert_block_id
(
removed
,
[
original_block_ids
[
3
],
original_block_ids
[
2
]])
assert_block_id
(
block_table
,
[
null_block_id
]
*
4
+
original_block_ids
[
4
:])
tests/v1/e2e/test_correctness_sliding_window.py
0 → 100644
View file @
3a5f0afc
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
...core.block.e2e.test_correctness_sliding_window
import
(
check_answers
,
prep_prompts
)
@
dataclass
class
TestConfig
:
sliding_window
:
int
ln_range
:
tuple
[
int
,
int
]
model_config
=
{
"bigcode/starcoder2-3b"
:
TestConfig
(
4096
,
(
800
,
1100
)),
"google/gemma-2-2b-it"
:
TestConfig
(
4096
,
(
400
,
800
)),
}
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"bigcode/starcoder2-3b"
,
# sliding window only
"google/gemma-2-2b-it"
,
# sliding window + full attention
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_retrival
(
monkeypatch
,
model
,
batch_size
,
seed
):
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
If we tell it upfront which we are going to be looking for, then
it answers correctly (mostly).
"""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
test_config
=
model_config
[
model
]
llm
=
LLM
(
model
=
model
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
prompts
,
answer
,
indices
=
prep_prompts
(
batch_size
,
ln_range
=
test_config
.
ln_range
)
check_length
(
prompts
,
llm
,
test_config
.
sliding_window
)
# Fresh generation
responses
=
llm
.
generate
(
prompts
,
sampling_params
)
check_answers
(
indices
,
answer
,
[
response
.
outputs
[
0
].
text
for
response
in
responses
],
accept_rate
=
1.0
)
# Re-generate with the same prompts to test prefix caching
responses
=
llm
.
generate
(
prompts
,
sampling_params
)
check_answers
(
indices
,
answer
,
[
response
.
outputs
[
0
].
text
for
response
in
responses
],
accept_rate
=
1.0
)
def
check_length
(
prompts
:
list
[
str
],
llm
:
LLM
,
sliding_window
:
int
):
"""
Check if the prompt length is valid, i.e., longer than the sliding window
size and shorter than the model's max length.
Args:
prompts: list of prompts
llm: LLM object
sliding_window: Sliding window size
"""
tokenizer
=
llm
.
get_tokenizer
()
max_model_len
=
llm
.
llm_engine
.
model_config
.
max_model_len
assert
any
(
len
(
tokenizer
.
encode
(
prompt
))
>
sliding_window
for
prompt
in
prompts
),
"Prompt is too short for test"
assert
all
(
len
(
tokenizer
.
encode
(
prompt
))
<=
max_model_len
for
prompt
in
prompts
),
"Prompt is too long for test"
vllm/config.py
View file @
3a5f0afc
...
...
@@ -1116,8 +1116,7 @@ class CacheConfig:
is_attention_free: Whether the model is attention-free.
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
sliding_window: Sliding window size for the KV cache. Can not work with
prefix caching enabled.
sliding_window: Sliding window size for the KV cache.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
"""
...
...
vllm/v1/core/block_pool.py
View file @
3a5f0afc
...
...
@@ -27,6 +27,7 @@ class BlockPool:
"""
def
__init__
(
self
,
num_gpu_blocks
:
int
,
enable_caching
:
bool
):
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
self
.
num_gpu_blocks
=
num_gpu_blocks
self
.
enable_caching
=
enable_caching
# All kv-cache blocks.
...
...
@@ -50,6 +51,11 @@ class BlockPool:
self
.
cached_block_hash_to_block
:
dict
[
BlockHashType
,
dict
[
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self
.
null_block
=
self
.
free_block_queue
.
popleft
()
def
get_cached_block
(
self
,
block_hash
:
BlockHashType
)
->
Optional
[
KVCacheBlock
]:
"""Get a cached block by the block hash, or None if cache miss.
...
...
@@ -214,7 +220,7 @@ class BlockPool:
for
block
in
blocks
:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
remove
(
block
)
block
.
incr_ref
()
...
...
@@ -228,7 +234,8 @@ class BlockPool:
"""
for
block
in
ordered_blocks
:
block
.
decr_ref
()
if
block
.
ref_cnt
==
0
:
# null_block should not be added to the free list.
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
append
(
block
)
def
reset_prefix_cache
(
self
)
->
bool
:
...
...
@@ -241,10 +248,10 @@ class BlockPool:
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
())
if
num_used_blocks
>
0
:
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet"
,
num_used_blocks
)
"blocks (%d) are not freed yet"
,
num_used_blocks
-
1
)
return
False
# Remove all hashes so that no new blocks will hit.
...
...
vllm/v1/core/kv_cache_manager.py
View file @
3a5f0afc
...
...
@@ -9,6 +9,8 @@ from vllm.utils import cdiv, sha256
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
KVCacheBlock
,
hash_request_tokens
)
from
vllm.v1.core.specialized_manager
import
get_specialized_manager
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
,
RequestStatus
...
...
@@ -19,20 +21,22 @@ class KVCacheManager:
def
__init__
(
self
,
block_size
:
int
,
num_gpu_blocks
:
int
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
sliding_window
:
Optional
[
int
]
=
None
,
enable_caching
:
bool
=
True
,
caching_hash_algo
:
str
=
"builtin"
,
num_preallocate_tokens
:
int
=
64
,
log_stats
:
bool
=
False
,
)
->
None
:
self
.
block_size
=
block_size
self
.
num_gpu_blocks
=
num_gpu_blocks
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group"
)
kv_cache_spec
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
num_gpu_blocks
=
kv_cache_config
.
num_blocks
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
max_model_len
,
block_size
)
self
.
sliding_window
=
sliding_window
self
.
max_num_blocks_per_req
=
cdiv
(
max_model_len
,
self
.
block_size
)
self
.
enable_caching
=
enable_caching
self
.
caching_hash_fn
=
sha256
if
caching_hash_algo
==
"sha256"
else
hash
# FIXME: make prefix cache stats conditional on log_stats
...
...
@@ -48,9 +52,15 @@ class KVCacheManager:
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self
.
num_preallocate_tokens
=
num_preallocate_tokens
self
.
num_preallocate_blocks
=
cdiv
(
num_preallocate_tokens
,
block_size
)
self
.
num_preallocate_blocks
=
cdiv
(
num_preallocate_tokens
,
self
.
block_size
)
self
.
block_pool
=
BlockPool
(
self
.
num_gpu_blocks
,
enable_caching
)
self
.
block_pool
=
BlockPool
(
num_gpu_blocks
,
enable_caching
)
self
.
specialized_manager
=
get_specialized_manager
(
kv_cache_spec
=
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
...
...
@@ -117,17 +127,25 @@ class KVCacheManager:
self
.
prefix_cache_stats
.
requests
+=
1
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
# Check for cache hits
computed_blocks
=
[]
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash
# is not in the cached_block_hash_to_id, the following
# block hashes are not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
computed_blocks
.
append
(
cached_block
)
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
break
last_block_hash
=
None
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
block_hashes
.
append
(
last_block_hash
)
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
...
...
@@ -176,13 +194,24 @@ class KVCacheManager:
new_computed_blocks
=
new_computed_blocks
or
[]
req_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks
=
self
.
specialized_manager
.
remove_skipped_blocks
(
req_blocks
,
request
.
num_computed_tokens
)
self
.
block_pool
.
free_blocks
(
removed_blocks
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
len
(
new_computed_blocks
)
*
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
,
self
.
block_size
)
req_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
len
(
new_computed_blocks
))
...
...
vllm/v1/core/kv_cache_utils.py
View file @
3a5f0afc
...
...
@@ -9,8 +9,9 @@ from typing import Any, Callable, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -483,7 +484,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
max_model_len
=
vllm_config
.
model_config
.
max_model_len
needed_memory
=
0
for
layer_spec
in
kv_cache_spec
.
values
():
needed_memory
+=
layer_spec
.
bytes_for_tokens
(
max_model_len
)
needed_memory
+=
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
if
needed_memory
>
available_memory
:
raise
ValueError
(
...
...
@@ -597,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return
kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
    """
    Collapse a hybrid full-attention / sliding-window model onto a single
    KV cache type.

    Only models with one type of KV cache are supported for now. When the
    given specs contain both FullAttentionSpec and SlidingWindowSpec entries,
    every SlidingWindowSpec is replaced, in place, by a FullAttentionSpec
    built from the same block_size, num_kv_heads, head_size, dtype and
    use_mla. In every other case the mapping is left untouched.

    Args:
        kv_cache_spec: The kv cache spec of each attention layer in the
            model, keyed by layer name. Mutated in place.
    """
    all_specs = kv_cache_spec.values()
    has_full_attention = any(
        isinstance(candidate, FullAttentionSpec) for candidate in all_specs)
    has_sliding_window = any(
        isinstance(candidate, SlidingWindowSpec) for candidate in all_specs)

    # Nothing to unify unless both kinds of layers are present.
    if not (has_full_attention and has_sliding_window):
        return

    # Only values of existing keys are reassigned, so the dict size never
    # changes while it is being iterated.
    for layer_name, spec in kv_cache_spec.items():
        if not isinstance(spec, SlidingWindowSpec):
            continue
        kv_cache_spec[layer_name] = FullAttentionSpec(
            block_size=spec.block_size,
            num_kv_heads=spec.num_kv_heads,
            head_size=spec.head_size,
            dtype=spec.dtype,
            use_mla=spec.use_mla,
        )
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
...
...
@@ -613,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
unify_hybrid_kv_cache_specs
(
kv_cache_spec
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
...
...
vllm/v1/core/sched/scheduler.py
View file @
3a5f0afc
...
...
@@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
...
...
@@ -35,6 +36,7 @@ class Scheduler(SchedulerInterface):
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_config
:
KVCacheConfig
,
structured_output_manager
:
StructuredOutputManager
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
...
...
@@ -43,6 +45,7 @@ class Scheduler(SchedulerInterface):
self
.
scheduler_config
=
scheduler_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
kv_cache_config
=
kv_cache_config
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
...
...
@@ -58,15 +61,11 @@ class Scheduler(SchedulerInterface):
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_model_len
=
self
.
scheduler_config
.
max_model_len
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
# Create the KV cache manager.
self
.
kv_cache_manager
=
KVCacheManager
(
block_size
=
self
.
cache_config
.
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
sliding_window
=
self
.
cache_config
.
sliding_window
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
,
enable_caching
=
cache_config
.
enable_prefix_caching
,
caching_hash_algo
=
self
.
cache_config
.
prefix_caching_hash_algo
,
log_stats
=
self
.
log_stats
)
self
.
block_size
=
self
.
cache_config
.
block_size
...
...
@@ -300,17 +299,6 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
if
num_new_tokens
==
0
:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens
-=
self
.
block_size
num_new_tokens
=
self
.
block_size
computed_blocks
.
pop
()
if
(
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
num_new_tokens
):
num_new_tokens
=
(
...
...
vllm/v1/core/specialized_manager.py
0 → 100644
View file @
3a5f0afc
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
vllm.utils
import
cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
SlidingWindowSpec
)
class SpecializedManager(ABC):
    """
    Abstract interface for the per-attention-type part of KV cache
    management.

    Each concrete subclass implements the prefix-cache lookup and the
    "which blocks can be dropped" policy for one kind of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
    ) -> None:
        """
        Store the spec and the block pool used by this manager.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        # Cached separately so subclasses do not have to go through the spec.
        self.block_size = kv_cache_spec.block_size

    @abstractmethod
    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Look up the longest prefix of `block_hashes` that is already cached.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            The cached blocks of the hit prefix, or an empty list when there
            is no hit. Blocks that are skipped are represented by the null
            block. For example, the sliding window manager should return a
            list like [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for
            block size 4 and sliding window 8.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Drop the blocks in `blocks` that are no longer needed.

        Implementations replace every dropped entry of `blocks` with the
        null block and must NOT free the dropped blocks themselves; the
        caller is responsible for that.

        Args:
            blocks: The list of blocks to be updated (modified in place).
            num_computed_tokens: The number of tokens that have been
                computed.

        Returns:
            The removed blocks in eviction order, i.e. the first returned
            block should be evicted first.
        """
        raise NotImplementedError
class FullAttentionManager(SpecializedManager):
    """Specialized manager for layers that attend to the whole prefix."""

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Return the cached blocks for the longest leading run of
        `block_hashes` found in the block pool.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            The cached blocks, in order, up to (not including) the first
            hash that is not cached. Empty when the first hash misses.
        """
        hit_blocks: list[KVCacheBlock] = []
        for block_hash in block_hashes:
            cached_block = self.block_pool.get_cached_block(block_hash)
            if not cached_block:
                # block_hashes is a chain of block hashes. Once one hash is
                # not in the cache, the following block hashes are not
                # computed yet for sure, so stop at the first miss.
                break
            hit_blocks.append(cached_block)
        return hit_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Full attention never skips earlier tokens, so nothing is removed.

        Returns:
            Always an empty list; `blocks` is left unchanged.
        """
        return []
class SlidingWindowManager(SpecializedManager):
    """Specialized manager for sliding-window attention layers."""

    def __init__(self, kv_cache_spec: SlidingWindowSpec,
                 block_pool: BlockPool):
        """
        Args:
            kv_cache_spec: The sliding window spec handled by this manager.
            block_pool: The block pool; its `null_block` is used as the
                placeholder for skipped blocks.
        """
        super().__init__(kv_cache_spec, block_pool)
        self._null_block = block_pool.null_block
        self.sliding_window = kv_cache_spec.sliding_window
        # The number of contiguous cached blocks needed to count as a prefix
        # cache hit. The window size is reduced by one because the input
        # token itself is also included in the window.
        self.sliding_window_contiguous_blocks = cdiv(
            (kv_cache_spec.sliding_window - 1), self.block_size)

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Find the right-most run of cached blocks that is long enough to cover
        the sliding window.

        The hashes are scanned from the last one to the first. As soon as
        `sliding_window_contiguous_blocks` consecutive hashes are cached, the
        result is cut right after that run and returned; everything before
        the run is filled with the null block. If no such run exists, the
        cached run that starts at index 0 (possibly empty) is returned.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            A list of blocks with skipped positions replaced by the null
            block, e.g. [NULL, NULL, 8, 3].
        """
        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss,
        # to optimize the time complexity from O(len(block_hashes)) to
        # O(len(block_hashes) / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        num_hashes = len(block_hashes)
        required_run = self.sliding_window_contiguous_blocks
        hit_blocks = [self._null_block] * num_hashes
        run_length = 0

        # Walk from the right so the first long-enough run found is the one
        # closest to the end of the request.
        for idx in reversed(range(num_hashes)):
            cached_block = self.block_pool.get_cached_block(block_hashes[idx])
            if not cached_block:
                # A miss breaks the current run.
                run_length = 0
                continue
            hit_blocks[idx] = cached_block
            run_length += 1
            if run_length >= required_run:
                # Trim the trailing blocks.
                # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                # when sliding_window_contiguous_blocks=2.
                del hit_blocks[idx + run_length:]
                return hit_blocks

        # The first `run_length` blocks are a cache hit even if
        # `run_length < sliding_window_contiguous_blocks`.
        del hit_blocks[run_length:]
        return hit_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Replace the blocks that have fallen out of the sliding window with
        the null block.

        Every block whose index is below
        `(num_computed_tokens - sliding_window + 1) // block_size` is
        replaced, walking from the highest such index down to 0 and stopping
        at the first entry that is already the null block.

        Args:
            blocks: The request's blocks (modified in place).
            num_computed_tokens: The number of tokens that have been
                computed.

        Returns:
            The removed blocks, in the order they were visited (highest
            index first). They are not freed here.
        """
        first_token_in_window = num_computed_tokens - self.sliding_window + 1
        first_block_in_window = first_token_in_window // self.block_size

        evicted: list[KVCacheBlock] = []
        idx = first_block_in_window - 1
        while idx >= 0:
            if blocks[idx] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous
                # calls to this function.
                break
            evicted.append(blocks[idx])
            blocks[idx] = self._null_block
            idx -= 1
        return evicted
# Registry mapping each supported KV cache spec class to the manager class
# that implements its cache-hit and block-removal policy.
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}


def get_specialized_manager(kv_cache_spec: KVCacheSpec,
                            block_pool: BlockPool) -> SpecializedManager:
    """
    Build the specialized manager registered for the given spec.

    The lookup uses the exact class of `kv_cache_spec` (not isinstance), so a
    spec whose class is not a key of `spec_manager_map` raises KeyError.

    Args:
        kv_cache_spec: The spec that selects the manager class and is passed
            to its constructor.
        block_pool: The block pool handed to the manager.

    Returns:
        A new manager instance for `kv_cache_spec`.
    """
    return spec_manager_map[type(kv_cache_spec)](kv_cache_spec, block_pool)
vllm/v1/engine/core.py
View file @
3a5f0afc
...
...
@@ -33,6 +33,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
...
@@ -66,8 +67,9 @@ class EngineCore:
self
.
model_executor
=
executor_class
(
vllm_config
)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks
,
num_cpu_blocks
=
self
.
_initialize_kv_caches
(
vllm_config
)
num_gpu_blocks
,
num_cpu_blocks
,
kv_cache_config
=
\
self
.
_initialize_kv_caches
(
vllm_config
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
...
...
@@ -95,10 +97,11 @@ class EngineCore:
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
kv_cache_config
=
kv_cache_config
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
log_stats
=
self
.
log_stats
,
structured_output_manager
=
self
.
structured_output_manager
,
)
# Setup MM Input Mapper.
...
...
@@ -117,8 +120,8 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
]:
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
start
=
time
.
time
()
# Get all kv cache needed by the model
...
...
@@ -143,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs
(
kv_cache_configs
)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to
get the number of blocks
.
# an arbitrary one to
initialize the scheduler
.
assert
all
([
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
for
cfg
in
kv_cache_configs
])
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_cpu_blocks
=
0
scheduler_kv_cache_config
=
kv_cache_configs
[
0
]
# Initialize kv cache and warmup the execution
self
.
model_executor
.
initialize_from_config
(
kv_cache_configs
)
...
...
@@ -157,7 +161,7 @@ class EngineCore:
elapsed
=
time
.
time
()
-
start
logger
.
info
((
"init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"
),
elapsed
)
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
,
scheduler_kv_cache_config
def
add_request
(
self
,
request
:
EngineCoreRequest
):
"""Add request to the scheduler."""
...
...
vllm/v1/kv_cache_interface.py
View file @
3a5f0afc
...
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
,
get_dtype_size
...
...
@@ -43,28 +44,23 @@ class KVCacheSpec:
"""
raise
NotImplementedError
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
memory size after padding `num_tokens` to full blocks.
The maximum possible memory usage of this KV cache in bytes.
Returns:
The KV cache size
The KV cache size
in bytes
"""
raise
NotImplementedError
@
dataclass
class
Full
AttentionSpec
(
KVCacheSpec
):
class
AttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
head_size
:
int
dtype
:
torch
.
dtype
use_mla
:
bool
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
@
property
def
page_size_bytes
(
self
)
->
int
:
# For MLA we only store a single latent vector
...
...
@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return
coef
*
self
.
block_size
*
self
.
num_kv_heads
*
self
.
head_size
\
*
get_dtype_size
(
self
.
dtype
)
def
bytes_for_tokens
(
self
,
num_tokens
:
int
)
->
int
:
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
@dataclass
class FullAttentionSpec(AttentionSpec):
    """KV cache spec for attention layers that keep the whole context."""

    @property
    def type_id(self) -> str:
        """Identifier built from the block size and the page size in bytes."""
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Computed as the number of blocks needed to hold `max_model_len`
        tokens (rounded up to whole blocks) times the page size.

        Returns:
            The KV cache size in bytes
        """
        num_blocks = cdiv(vllm_config.model_config.max_model_len,
                          self.block_size)
        return num_blocks * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
    """KV cache spec for sliding-window attention layers."""

    # Sliding window size; treated as a token count in the computations below.
    sliding_window: int

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

    @property
    def type_id(self) -> str:
        """
        Identifier built from the window size, the block size and the page
        size in bytes.
        """
        return (f"sliding_window_{self.sliding_window}_{self.block_size}_"
                f"{self.page_size_bytes}")

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        model_len_limit = vllm_config.model_config.max_model_len
        batch_token_limit = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than
        # `max_model_len` tokens.
        num_tokens = min(self.sliding_window - 1 + batch_token_limit,
                         model_len_limit)

        # One extra block because the sliding window may not start at a block
        # boundary. For example, with block size 4, a window of 4 tokens
        # [CDEF] that starts in the middle of a block occupies two blocks,
        # [XXCD] [EF], although cdiv(4, 4) is only 1.
        num_blocks = cdiv(num_tokens, self.block_size) + 1
        return num_blocks * self.page_size_bytes
@
dataclass
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3a5f0afc
...
...
@@ -28,8 +28,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
check_use_alibi
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
...
@@ -1572,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
Full
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
...
...
@@ -1611,6 +1612,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# cross-attention
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
use_mla
)
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
3a5f0afc
...
...
@@ -29,7 +29,7 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
...
...
@@ -353,10 +353,18 @@ class TPUModelRunner:
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert
isinstance
(
attn_module
,
Attention
)
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
attn_module
.
dtype
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
False
,
)
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment