Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f8a1a2d1
Unverified
Commit
f8a1a2d1
authored
Jun 06, 2025
by
Chen Zhang
Committed by
GitHub
Jun 05, 2025
Browse files
[v1] Hybrid Memory Allocator (#17996)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
3465b87e
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1597 additions
and
438 deletions
+1597
-438
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+210
-35
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+350
-100
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+8
-8
tests/v1/core/test_specialized_manager.py
tests/v1/core/test_specialized_manager.py
+16
-11
tests/v1/e2e/test_correctness_sliding_window.py
tests/v1/e2e/test_correctness_sliding_window.py
+2
-2
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+2
-2
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+2
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+12
-12
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+5
-5
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+14
-13
vllm/config.py
vllm/config.py
+21
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+44
-25
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+358
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+76
-59
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+271
-51
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-3
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+98
-53
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+9
-24
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+87
-33
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
f8a1a2d1
...
@@ -15,8 +15,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
...
@@ -15,8 +15,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from
vllm.v1.core.kv_cache_utils
import
(
from
vllm.v1.core.kv_cache_utils
import
(
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_max_concurrency_for_kv_cache_config
,
hash_block_tokens
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
hash_request_tokens
,
unify_kv_cache_configs
)
hash_block_tokens
,
hash_request_tokens
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
SlidingWindowSpec
)
...
@@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
...
@@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
sliding_window
=
sliding_window
)
sliding_window
=
sliding_window
)
def
new_sliding_window_spec
(
block_size
=
16
,
num_kv_heads
=
2
,
head_size
=
64
,
dtype
=
torch
.
float32
,
use_mla
=
False
,
sliding_window
=
1
):
return
SlidingWindowSpec
(
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
use_mla
=
use_mla
,
sliding_window
=
sliding_window
)
def
test_none_hash
(
monkeypatch
):
def
test_none_hash
(
monkeypatch
):
import
vllm.v1.core.kv_cache_utils
import
vllm.v1.core.kv_cache_utils
...
@@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
...
@@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
same_kv_cache_config
=
[
same_kv_cache_config
=
[
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
...
@@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
...
@@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
),
),
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
20
,
num_blocks
=
20
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
...
@@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
...
@@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
need_sort_kv_cache_config
=
[
need_sort_kv_cache_config
=
[
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
...
@@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
...
@@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
),
),
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
20
,
num_blocks
=
20
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
new_kv_cache_spec
(
num_kv_heads
=
4
)),
new_kv_cache_spec
(
num_kv_heads
=
4
)),
...
@@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
...
@@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
diff_kv_cache_config
=
[
diff_kv_cache_config
=
[
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
...
@@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
...
@@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
),
),
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
20
,
num_blocks
=
20
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
"layer2"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer2"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer2"
],
KVCacheGroupSpec
([
"layer2"
],
...
@@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():
...
@@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_full_attention
=
KVCacheConfig
(
kv_cache_config_full_attention
=
KVCacheConfig
(
num_blocks
=
int
(
1024
*
1.5
),
num_blocks
=
int
(
1024
*
1.5
),
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
full_attention_spec
),
full_attention_spec
),
...
@@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():
...
@@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_sliding_window
=
KVCacheConfig
(
kv_cache_config_sliding_window
=
KVCacheConfig
(
num_blocks
=
129
*
3
,
num_blocks
=
129
*
3
,
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
sliding_window_spec
),
sliding_window_spec
),
...
@@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():
...
@@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_hybrid_model
=
KVCacheConfig
(
kv_cache_config_hybrid_model
=
KVCacheConfig
(
num_blocks
=
(
1024
+
129
)
*
3
,
num_blocks
=
(
1024
+
129
)
*
3
,
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
KVCacheGroupSpec
([
f
"layer_
{
i
}
"
for
i
in
range
(
32
)],
full_attention_spec
),
full_attention_spec
),
...
@@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
...
@@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
block_size
=
4
block_size
=
4
config
=
KVCacheConfig
(
config
=
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
tensors
=
{
kv_cache_
tensors
=
[
"layer1"
:
KVCacheTensor
(
100
),
KVCacheTensor
(
size
=
100
,
shared_by
=
[
"layer1"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer1"
],
KVCacheGroupSpec
([
"layer1"
],
new_kv_cache_spec
(
block_size
=
block_size
)),
new_kv_cache_spec
(
block_size
=
block_size
)),
...
@@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
...
@@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
num_new_tokens
=
3
,
num_new_tokens
=
3
,
num_lookahead_tokens
=
2
,
# Total required: 3+2=5 tokens
num_lookahead_tokens
=
2
,
# Total required: 3+2=5 tokens
)
)
assert
len
(
blocks
.
block
s
)
==
2
# ceil(5/4)=2 blocks
assert
len
(
blocks
.
get_
block
_ids
()[
0
]
)
==
2
# ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
# Test case 2: With precomputed blocks
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
kv_cache_manager
=
KVCacheManager
(
kv_cache_config
=
config
,
...
@@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
...
@@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
num_new_tokens
=
3
,
num_new_tokens
=
3
,
num_lookahead_tokens
=
2
,
num_lookahead_tokens
=
2
,
)
)
assert
len
(
blocks
.
block
s
)
==
2
assert
len
(
blocks
.
get_
block
_ids
()[
0
]
)
==
2
# Test case 3: With precomputed blocks
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
# required_blocks = ceil((3 + 4) / 4) = 2
...
@@ -724,4 +738,165 @@ def test_allocate_with_lookahead():
...
@@ -724,4 +738,165 @@ def test_allocate_with_lookahead():
num_new_tokens
=
3
,
num_new_tokens
=
3
,
num_lookahead_tokens
=
4
,
num_lookahead_tokens
=
4
,
)
)
assert
len
(
blocks
.
blocks
)
==
2
assert
len
(
blocks
.
get_block_ids
()[
0
])
==
2
def
test_get_kv_cache_config
():
# pass max_model_len to pass check_enough_kv_cache_memory
model_config
=
ModelConfig
(
max_model_len
=
16
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
)
mem_per_block_per_layer
=
16
*
2
*
64
*
4
*
2
# all layers are full attention -> single group
kv_cache_specs_full
=
{
'layer_1'
:
new_kv_cache_spec
(),
'layer_2'
:
new_kv_cache_spec
(),
}
kv_cache_config_full
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_full
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_full
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
],
new_kv_cache_spec
())
])
# all layers are sliding window -> single group
kv_cache_specs_sliding
=
{
'layer_1'
:
new_sliding_window_spec
(),
'layer_2'
:
new_sliding_window_spec
(),
}
kv_cache_config_sliding
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_sliding
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_sliding
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
],
new_sliding_window_spec
())
])
# full + sliding, but disable_hybrid_kv_cache_manager
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
True
kv_cache_specs_hybrid
=
{
'layer_1'
:
new_kv_cache_spec
(),
'layer_2'
:
new_sliding_window_spec
(),
}
kv_cache_config_hybrid
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_hybrid
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
],
new_kv_cache_spec
(
sliding_window
=
1
)),
],
)
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
False
# full + sliding, with hybrid_kv_cache_manager
kv_cache_specs_hybrid
=
{
'layer_1'
:
new_kv_cache_spec
(),
'layer_2'
:
new_sliding_window_spec
(),
}
kv_cache_config_hybrid
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_hybrid
==
KVCacheConfig
(
num_blocks
=
64
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
64
,
shared_by
=
[
"layer_1"
,
"layer_2"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer_2"
],
new_sliding_window_spec
()),
],
)
# 2 full + 4 sliding, 2 layers per group
kv_cache_specs_hybrid
=
{
'layer_1'
:
new_kv_cache_spec
(),
'layer_2'
:
new_kv_cache_spec
(),
'layer_3'
:
new_sliding_window_spec
(),
'layer_4'
:
new_sliding_window_spec
(),
'layer_5'
:
new_sliding_window_spec
(),
'layer_6'
:
new_sliding_window_spec
(),
}
kv_cache_config_hybrid
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
2
*
32
)
assert
kv_cache_config_hybrid
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
,
"layer_3"
,
"layer_5"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
,
"layer_4"
,
"layer_6"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer_3"
,
"layer_4"
],
new_sliding_window_spec
()),
KVCacheGroupSpec
([
"layer_5"
,
"layer_6"
],
new_sliding_window_spec
()),
],
)
# 3 full + 7 sliding, pad to 3 full + 9 sliding
kv_cache_specs_hybrid
=
{
'layer_1'
:
new_kv_cache_spec
(),
'layer_2'
:
new_kv_cache_spec
(),
'layer_3'
:
new_kv_cache_spec
(),
'layer_4'
:
new_sliding_window_spec
(),
'layer_5'
:
new_sliding_window_spec
(),
'layer_6'
:
new_sliding_window_spec
(),
'layer_7'
:
new_sliding_window_spec
(),
'layer_8'
:
new_sliding_window_spec
(),
'layer_9'
:
new_sliding_window_spec
(),
'layer_10'
:
new_sliding_window_spec
(),
}
kv_cache_config_hybrid
=
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
3
*
32
)
assert
kv_cache_config_hybrid
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
,
"layer_4"
,
"layer_7"
,
"layer_10"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
,
"layer_5"
,
"layer_8"
]),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_3"
,
"layer_6"
,
"layer_9"
]),
],
kv_cache_groups
=
[
KVCacheGroupSpec
([
"layer_1"
,
"layer_2"
,
"layer_3"
],
new_kv_cache_spec
()),
KVCacheGroupSpec
([
"layer_4"
,
"layer_5"
,
"layer_6"
],
new_sliding_window_spec
()),
KVCacheGroupSpec
([
"layer_7"
,
"layer_8"
,
"layer_9"
],
new_sliding_window_spec
()),
KVCacheGroupSpec
([
"layer_10"
],
new_sliding_window_spec
()),
],
)
# different hidden size, unimplemented
kv_cache_specs_hybrid
=
{
'layer_1'
:
new_kv_cache_spec
(
head_size
=
128
),
'layer_2'
:
new_kv_cache_spec
(),
}
with
pytest
.
raises
(
NotImplementedError
):
get_kv_cache_config
(
vllm_config
,
kv_cache_specs_hybrid
,
mem_per_block_per_layer
*
2
*
32
)
tests/v1/core/test_prefix_caching.py
View file @
f8a1a2d1
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching."""
"""Compare the with and without prefix caching."""
import
copy
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -13,8 +14,8 @@ from vllm.sampling_params import SamplingParams
...
@@ -13,8 +14,8 @@ from vllm.sampling_params import SamplingParams
from
vllm.utils
import
sha256
from
vllm.utils
import
sha256
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
hash_block_tokens
)
KVCacheBlock
,
hash_block_tokens
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
)
KVCacheGroupSpec
,
SlidingWindowSpec
)
...
@@ -47,7 +48,7 @@ def make_request(request_id,
...
@@ -47,7 +48,7 @@ def make_request(request_id,
def
make_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
def
make_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
return
KVCacheConfig
(
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
num_blocks
=
num_blocks
,
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
(
KVCacheGroupSpec
(
[
"layer"
],
[
"layer"
],
...
@@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
...
@@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
)
)
def
make_kv_cache_config_hybrid_model
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
return
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer1"
],
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
),
),
KVCacheGroupSpec
(
[
"layer2"
],
SlidingWindowSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
,
sliding_window
=
2
*
block_size
),
),
KVCacheGroupSpec
(
[
"layer3"
],
SlidingWindowSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
,
sliding_window
=
2
*
block_size
),
),
],
)
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"hash"
])
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"hash"
])
def
test_prefill
(
hash_algo
):
def
test_prefill
(
hash_algo
):
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
...
@@ -79,10 +112,10 @@ def test_prefill(hash_algo):
...
@@ -79,10 +112,10 @@ def test_prefill(hash_algo):
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
3
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
...
@@ -92,7 +125,8 @@ def test_prefill(hash_algo):
...
@@ -92,7 +125,8 @@ def test_prefill(hash_algo):
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
.
hash_value
...
@@ -111,10 +145,10 @@ def test_prefill(hash_algo):
...
@@ -111,10 +145,10 @@ def test_prefill(hash_algo):
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
5
]]
assert
blocks
.
get_block_ids
()
==
[[
5
]]
for
block
in
computed_blocks
.
blocks
:
for
block
in
computed_blocks
.
blocks
[
0
]
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
# At this point, we should have 5 free blocks left.
# At this point, we should have 5 free blocks left.
...
@@ -145,7 +179,7 @@ def test_prefill(hash_algo):
...
@@ -145,7 +179,7 @@ def test_prefill(hash_algo):
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
6
]]
assert
blocks
.
get_block_ids
()
==
[[
6
]]
...
@@ -165,10 +199,10 @@ def test_prefill(hash_algo):
...
@@ -165,10 +199,10 @@ def test_prefill(hash_algo):
# Cache miss and eviction.
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
10
))
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
10
))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
10
,
blocks
=
manager
.
allocate_slots
(
req3
,
16
*
10
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
# This block ID order also checks the eviction order.
# This block ID order also checks the eviction order.
assert
blocks
.
get_block_ids
()
==
[[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]]
assert
blocks
.
get_block_ids
()
==
[[
7
,
8
,
9
,
10
,
4
,
5
,
6
,
3
,
2
,
1
]]
...
@@ -177,6 +211,138 @@ def test_prefill(hash_algo):
...
@@ -177,6 +211,138 @@ def test_prefill(hash_algo):
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
assert
manager
.
block_pool
.
free_block_queue
.
free_list_tail
is
None
def
test_prefill_hybrid_model
():
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config_hybrid_model
(
block_size
,
21
),
max_model_len
=
8192
,
enable_caching
=
True
,
)
hash_fn
=
hash
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
3
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]]
# Check full block metadata
parent_block_hash
=
None
for
length
,
block_ids
in
zip
((
1
,
2
,
3
),
((
1
,
5
,
9
),
(
2
,
6
,
10
),
(
3
,
7
,
11
))):
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
for
block_id
in
block_ids
:
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
# Check partial block metadata
for
block_id
in
(
4
,
8
,
12
):
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
is
None
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
computed_blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
]]
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
13
],
[
14
],
[
15
]]
for
block_per_group
in
computed_blocks
.
blocks
:
for
block
in
block_per_group
:
if
block
!=
manager
.
block_pool
.
null_block
:
assert
block
.
ref_cnt
==
2
block_hashes
=
manager
.
req_to_block_hashes
[
req1
.
request_id
]
manager
.
free
(
req0
)
manager
.
free
(
req1
)
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
)
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
)
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
hash_with_group_id
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
len
(
manager
.
req_to_block_hashes
[
req
.
request_id
])
==
3
assert
num_computed_tokens
==
expect_hit_length
*
block_size
for
block_per_group
in
computed_blocks
.
blocks
:
assert
len
(
block_per_group
)
==
num_computed_tokens
//
block_size
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
[
hash_with_group_id
]
=
cached_block_hash_to_block_bak
[
hash_with_group_id
]
manager
.
free
(
req
)
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
"2"
,
[
BlockHashWithGroupId
(
block_hashes
[
0
],
1
),
BlockHashWithGroupId
(
block_hashes
[
0
],
2
)
],
3
)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
"3"
,
[
BlockHashWithGroupId
(
block_hashes
[
0
],
0
),
],
0
)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
"4"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
),
BlockHashWithGroupId
(
block_hashes
[
2
],
1
),
BlockHashWithGroupId
(
block_hashes
[
2
],
2
),
],
2
)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
"5"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"6"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
1
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"7"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
2
)],
2
)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
# The cache hit length of full attention is 1 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit
(
"8"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
),
BlockHashWithGroupId
(
block_hashes
[
0
],
1
),
BlockHashWithGroupId
(
block_hashes
[
0
],
2
),
],
0
)
def
test_prefill_plp
():
def
test_prefill_plp
():
'''Test prefill with APC and some prompt logprobs (plp) requests.
'''Test prefill with APC and some prompt logprobs (plp) requests.
...
@@ -203,13 +369,13 @@ def test_prefill_plp():
...
@@ -203,13 +369,13 @@ def test_prefill_plp():
req0
=
make_request
(
"0"
,
all_token_ids
,
prompt_logprobs
=
5
)
req0
=
make_request
(
"0"
,
all_token_ids
,
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
0
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
])
==
0
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
req0_block_hashes
=
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]
]
# Check full block metadata
# Check full block metadata
parent_block_hash
=
None
parent_block_hash
=
None
...
@@ -217,7 +383,8 @@ def test_prefill_plp():
...
@@ -217,7 +383,8 @@ def test_prefill_plp():
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
.
hash_value
...
@@ -237,10 +404,10 @@ def test_prefill_plp():
...
@@ -237,10 +404,10 @@ def test_prefill_plp():
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
5
]]
assert
blocks
.
get_block_ids
()
==
[[
5
]]
for
block
in
computed_blocks
.
blocks
:
for
block
in
computed_blocks
.
blocks
[
0
]
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
# At this point, we should have 5 free blocks left.
# At this point, we should have 5 free blocks left.
...
@@ -269,14 +436,14 @@ def test_prefill_plp():
...
@@ -269,14 +436,14 @@ def test_prefill_plp():
prompt_logprobs
=
5
)
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
0
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
])
==
0
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
55
,
blocks
=
manager
.
allocate_slots
(
req2
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
block_ids
=
blocks
.
get_block_ids
()
block_ids
=
blocks
.
get_block_ids
()
# Duplicate cached blocks have different ids but same hashes vs request #0
# Duplicate cached blocks have different ids but same hashes vs request #0
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
]
==
req0_block_hashes
assert
[
b
.
block_hash
for
b
in
blocks
.
blocks
[
0
]
]
==
req0_block_hashes
assert
block_ids
!=
[[
1
,
2
,
3
,
4
]]
assert
block_ids
!=
[[
1
,
2
,
3
,
4
]]
# Request #2 block hashes are valid since request #0 hashes are.
# Request #2 block hashes are valid since request #0 hashes are.
...
@@ -302,10 +469,10 @@ def test_decode():
...
@@ -302,10 +469,10 @@ def test_decode():
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
...
@@ -314,10 +481,10 @@ def test_decode():
...
@@ -314,10 +481,10 @@ def test_decode():
for
_
in
range
(
4
):
for
_
in
range
(
4
):
req0
.
append_output_token_ids
(
8
)
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
4
,
new_blocks
=
manager
.
allocate_slots
(
req0
,
4
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
]
)
==
0
assert
manager
.
single_type_manager
.
req_to_blocks
[
assert
manager
.
coordinator
.
single_type_manager
s
[
0
]
.
req_to_blocks
[
req0
.
request_id
][
-
1
].
block_hash
is
None
req0
.
request_id
][
-
1
].
block_hash
is
None
# Append slots with allocating a new block.
# Append slots with allocating a new block.
...
@@ -327,12 +494,12 @@ def test_decode():
...
@@ -327,12 +494,12 @@ def test_decode():
for
_
in
range
(
9
+
10
):
for
_
in
range
(
9
+
10
):
req0
.
append_output_token_ids
(
7
)
req0
.
append_output_token_ids
(
7
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
19
,
new_blocks
=
manager
.
allocate_slots
(
req0
,
19
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
)
==
1
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
]
)
==
1
assert
manager
.
single_type_manager
.
req_to_blocks
[
assert
manager
.
coordinator
.
single_type_manager
s
[
0
]
.
req_to_blocks
[
req0
.
request_id
][
-
2
].
block_hash
is
not
None
req0
.
request_id
][
-
2
].
block_hash
is
not
None
assert
manager
.
single_type_manager
.
req_to_blocks
[
assert
manager
.
coordinator
.
single_type_manager
s
[
0
]
.
req_to_blocks
[
req0
.
request_id
][
-
1
].
block_hash
is
None
req0
.
request_id
][
-
1
].
block_hash
is
None
...
@@ -346,23 +513,23 @@ def test_evict():
...
@@ -346,23 +513,23 @@ def test_evict():
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
blocks
=
manager
.
allocate_slots
(
req0
,
5
*
16
+
7
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
6
# 5 full + 1 partial
assert
len
(
blocks
.
blocks
[
0
]
)
==
6
# 5 full + 1 partial
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)))
last_token_id
+
3
*
16
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
blocks
=
manager
.
allocate_slots
(
req1
,
3
*
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
3
# 3 full blocks
assert
len
(
blocks
.
blocks
[
0
]
)
==
3
# 3 full blocks
last_token_id
+=
3
*
16
last_token_id
+=
3
*
16
# 10 - (6 + 3) == 1
# 10 - (6 + 3) == 1
...
@@ -382,7 +549,7 @@ def test_evict():
...
@@ -382,7 +549,7 @@ def test_evict():
assert
computed_blocks
.
get_block_ids
()
==
[[
1
,
2
]]
assert
computed_blocks
.
get_block_ids
()
==
[[
1
,
2
]]
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
blocks
=
manager
.
allocate_slots
(
req2
,
3
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
10
]]
assert
blocks
.
get_block_ids
()
==
[[
10
]]
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
7
...
@@ -404,12 +571,12 @@ def test_hash_block_correct_reuse():
...
@@ -404,12 +571,12 @@ def test_hash_block_correct_reuse():
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
,
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
1
assert
len
(
blocks
.
blocks
[
0
]
)
==
1
# Deallocate the block.
# Deallocate the block.
manager
.
free
(
req
)
manager
.
free
(
req
)
...
@@ -418,15 +585,15 @@ def test_hash_block_correct_reuse():
...
@@ -418,15 +585,15 @@ def test_hash_block_correct_reuse():
# block is cleared.
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
-
1
,
blocks
=
manager
.
allocate_slots
(
req
,
num_tokens
-
1
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
1
assert
len
(
blocks
.
blocks
[
0
]
)
==
1
assert
manager
.
block_pool
.
blocks
[
assert
manager
.
block_pool
.
blocks
[
blocks
.
blocks
[
0
]
blocks
.
blocks
[
0
].
block_id
].
block_hash
is
None
[
0
].
block_id
].
block_hash
is
None
def
test_computed_blocks_not_evicted
():
def
test_computed_blocks_not_evicted
():
...
@@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted():
...
@@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted():
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
blocks
=
manager
.
allocate_slots
(
req0
,
num_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
1
assert
len
(
blocks
.
blocks
[
0
]
)
==
1
assert
blocks
.
blocks
[
0
].
block_id
==
1
assert
blocks
.
blocks
[
0
]
[
0
]
.
block_id
==
1
# Allocate another block.
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
blocks
=
manager
.
allocate_slots
(
req1
,
num_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
1
assert
len
(
blocks
.
blocks
[
0
]
)
==
1
assert
blocks
.
blocks
[
0
].
block_id
==
2
assert
blocks
.
blocks
[
0
]
[
0
]
.
block_id
==
2
# Free the blocks.
# Free the blocks.
manager
.
free
(
req0
)
manager
.
free
(
req0
)
...
@@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted():
...
@@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted():
# cached block rather than the first one.
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
)
==
1
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
1
assert
computed_blocks
.
blocks
[
0
].
block_id
==
1
assert
computed_blocks
.
blocks
[
0
]
[
0
]
.
block_id
==
1
assert
num_computed_tokens
==
block_size
assert
num_computed_tokens
==
block_size
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
blocks
=
manager
.
allocate_slots
(
req2
,
num_tokens
*
2
-
num_tokens
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
1
assert
len
(
blocks
.
blocks
[
0
]
)
==
1
assert
blocks
.
blocks
[
0
].
block_id
==
2
assert
blocks
.
blocks
[
0
]
[
0
]
.
block_id
==
2
def
test_basic_prefix_caching_disabled
():
def
test_basic_prefix_caching_disabled
():
...
@@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled():
...
@@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled():
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req1
,
10
,
blocks
=
manager
.
allocate_slots
(
req1
,
10
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
3
assert
len
(
blocks
.
blocks
[
0
]
)
==
3
# Free the blocks.
# Free the blocks.
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled():
...
@@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled():
# No caching.
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
16
,
blocks
=
manager
.
allocate_slots
(
req2
,
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
len
(
blocks
.
blocks
)
==
4
assert
len
(
blocks
.
blocks
[
0
]
)
==
4
# New requests should not have any blocks.
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req3
,
4
,
blocks
=
manager
.
allocate_slots
(
req3
,
4
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
not
blocks
assert
not
blocks
...
@@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn):
...
@@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn):
num_full_blocks
=
2
,
num_full_blocks
=
2
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
hash_fn
=
hash_fn
,
kv_cache_group_id
=
0
,
)
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
2
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
2
...
@@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn):
...
@@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn):
num_full_blocks
=
3
,
num_full_blocks
=
3
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
hash_fn
=
hash_fn
,
kv_cache_group_id
=
0
,
)
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
3
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
3
assert
blocks
[
0
].
block_hash
is
not
None
assert
blocks
[
0
].
block_hash
is
not
None
def
test_cache_blocks_multi_group
():
"""
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size
=
4
block_pool
=
BlockPool
(
num_gpu_blocks
=
10
,
enable_caching
=
True
)
# Req:
# Block 0/4: [0, 1, 2, 3]
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)))
# Cache the blocks for group 0.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
block_hashes
:
list
[
BlockHash
]
=
[]
block_pool
.
cache_full_blocks
(
request
=
req
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
0
,
num_full_blocks
=
2
,
block_size
=
block_size
,
hash_fn
=
hash
,
kv_cache_group_id
=
0
,
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
2
assert
len
(
block_hashes
)
==
2
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
# Cache the blocks for group 1.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
3
)]
block_pool
.
cache_full_blocks
(
request
=
req
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
0
,
num_full_blocks
=
3
,
block_size
=
block_size
,
hash_fn
=
hash
,
kv_cache_group_id
=
1
,
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
5
assert
len
(
block_hashes
)
==
3
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
# Block hash 0: hit for group 0 and 1
# Block hash 1: hit for group 0 and 1
# Block hash 2: hit for group 1
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
kv_cache_group_ids
=
[
0
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
kv_cache_group_ids
=
[
0
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
kv_cache_group_ids
=
[
0
])
is
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
kv_cache_group_ids
=
[
0
,
1
])
is
None
def
test_mm_prefix_caching
():
def
test_mm_prefix_caching
():
"""
"""
This tests that the multi-modal prefix caching is correct.
This tests that the multi-modal prefix caching is correct.
...
@@ -614,7 +854,7 @@ def test_mm_prefix_caching():
...
@@ -614,7 +854,7 @@ def test_mm_prefix_caching():
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
...
@@ -623,7 +863,7 @@ def test_mm_prefix_caching():
...
@@ -623,7 +863,7 @@ def test_mm_prefix_caching():
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -632,9 +872,9 @@ def test_mm_prefix_caching():
...
@@ -632,9 +872,9 @@ def test_mm_prefix_caching():
for
_
in
range
(
5
):
for
_
in
range
(
5
):
req0
.
append_output_token_ids
(
8
)
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
5
,
new_blocks
=
manager
.
allocate_slots
(
req0
,
5
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
]
)
==
0
# The just completed block should have hashes with extra keys.
# The just completed block should have hashes with extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
...
@@ -652,7 +892,7 @@ def test_mm_prefix_caching():
...
@@ -652,7 +892,7 @@ def test_mm_prefix_caching():
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
computed_blocks
.
blocks
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
3
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
...
@@ -675,7 +915,7 @@ def test_cache_key_salting():
...
@@ -675,7 +915,7 @@ def test_cache_key_salting():
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
...
@@ -684,7 +924,7 @@ def test_cache_key_salting():
...
@@ -684,7 +924,7 @@ def test_cache_key_salting():
assert
block_hashes
[
2
].
extra_keys
is
None
assert
block_hashes
[
2
].
extra_keys
is
None
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
assert
blocks
.
get_block_ids
()
==
[[
1
,
2
,
3
,
4
]]
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -693,9 +933,9 @@ def test_cache_key_salting():
...
@@ -693,9 +933,9 @@ def test_cache_key_salting():
for
_
in
range
(
5
):
for
_
in
range
(
5
):
req0
.
append_output_token_ids
(
8
)
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
allocate_slots
(
req0
,
5
,
new_blocks
=
manager
.
allocate_slots
(
req0
,
5
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
)
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
]
)
==
0
# Now one more block that should not have extra keys.
# Now one more block that should not have extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
...
@@ -706,14 +946,14 @@ def test_cache_key_salting():
...
@@ -706,14 +946,14 @@ def test_cache_key_salting():
req1
=
make_request
(
"1"
,
token_ids
,
cache_salt
=
"salt1"
)
req1
=
make_request
(
"1"
,
token_ids
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
# Should match only a prefix of 3 blocks.
# Should match only a prefix of 3 blocks.
assert
len
(
computed_blocks
.
blocks
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
3
assert
num_computed_tokens
==
3
*
block_size
assert
num_computed_tokens
==
3
*
block_size
# Test cache miss with same content but different salt.
# Test cache miss with same content but different salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req2
=
make_request
(
"2"
,
token_ids
,
cache_salt
=
"salt2"
)
req2
=
make_request
(
"2"
,
token_ids
,
cache_salt
=
"salt2"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
)
==
0
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
0
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req2
.
request_id
]
block_hashes
=
manager
.
req_to_block_hashes
[
req2
.
request_id
]
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
...
@@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req0
,
48
,
manager
.
allocate_slots
(
req0
,
48
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
block_part0
=
manager
.
single_type_manager
.
req_to_blocks
[
req0
.
request_id
]
computed_blocks
)
block_part0
=
manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
.
blocks
==
block_part0
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
manager
.
allocate_slots
(
req1
,
48
,
manager
.
allocate_slots
(
req1
,
48
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
block_part1
=
manager
.
single_type_manager
.
req_to_blocks
[
req1
.
request_id
]
computed_blocks
)
block_part1
=
manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
req1
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
# | Req1-5(F)| ... |
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# but it cannot be allocated due to insufficient free blocks (2).
...
@@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
.
blocks
==
block_part1
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
num_computed_tokens
==
6
*
16
assert
num_computed_tokens
==
6
*
16
# Req3 cannot be allocated.
# Req3 cannot be allocated.
assert
manager
.
allocate_slots
(
req3
,
48
,
assert
manager
.
allocate_slots
(
req3
,
48
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
is
None
computed_blocks
)
is
None
# Block 0-2 are used by Req 1.
# Block 0-2 are used by Req 1.
assert
{
block
.
ref_cnt
for
block
in
block_part1
[:
3
]}
==
{
1
}
assert
{
block
.
ref_cnt
for
block
in
block_part1
[:
3
]}
==
{
1
}
...
@@ -804,9 +1049,9 @@ def test_reset_prefix_cache():
...
@@ -804,9 +1049,9 @@ def test_reset_prefix_cache():
req1
=
make_request
(
"1"
,
all_token_ids
)
req1
=
make_request
(
"1"
,
all_token_ids
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
])
==
3
assert
len
(
computed_blocks
.
blocks
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
3
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
)
*
16
,
len
(
computed_blocks
.
blocks
[
0
]
)
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
.
get_block_ids
()
==
[[
5
]]
assert
blocks
.
get_block_ids
()
==
[[
5
]]
...
@@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled():
...
@@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled():
# Call all functions that check whether log_stats is disabled.
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)))
req
=
make_request
(
"0"
,
list
(
range
(
16
)))
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req
,
16
,
manager
.
allocate_slots
(
req
,
16
,
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
manager
.
reset_prefix_cache
()
manager
.
reset_prefix_cache
()
# Ensure prefix_cache_stats remains None
# Ensure prefix_cache_stats remains None
...
@@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block():
...
@@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
# New request with same tokens + Eagle enabled
...
@@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block():
# Should retain 1 block:
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. drop last matched block → 1 remaining block
# 2. drop last matched block → 1 remaining block
assert
len
(
computed_blocks
.
blocks
)
==
1
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
1
assert
num_tokens
==
1
*
block_size
# 16 tokens
assert
num_tokens
==
1
*
block_size
# 16 tokens
...
@@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks():
...
@@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks():
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
)
==
1
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
1
assert
num_tokens
==
1
*
block_size
assert
num_tokens
==
1
*
block_size
...
@@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window():
...
@@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window():
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
KVCacheConfig
(
KVCacheConfig
(
num_blocks
=
10
,
num_blocks
=
10
,
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
sliding_window_spec
)],
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
sliding_window_spec
)],
),
),
max_model_len
=
8192
,
max_model_len
=
8192
,
...
@@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window():
...
@@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window():
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
manager
.
allocate_slots
(
req
,
len
(
token_ids
),
len
(
computed_blocks
.
blocks
)
*
16
,
computed_blocks
)
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
# record the block hash of the first block in the request for later use
# record the block hash of the first block in the request for later use
block_hash_first_block
=
manager
.
req_to_block_hashes
[
req
.
request_id
][
0
]
block_hash_first_block
=
manager
.
req_to_block_hashes
[
req
.
request_id
][
0
]
assert
block_hash_first_block
is
not
None
assert
block_hash_first_block
is
not
None
...
@@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window():
...
@@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window():
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
)
==
1
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
1
assert
num_tokens
==
1
*
block_size
assert
num_tokens
==
1
*
block_size
# Evict the first block in the request
# Evict the first block in the request
assert
manager
.
block_pool
.
get_cached_block
(
assert
manager
.
block_pool
.
get_cached_block
(
block_hash_first_block
)
is
not
None
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
block_hash_first_block
)
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
BlockHashWithGroupId
(
block_hash_first_block
,
0
))
# New request
# New request
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
)
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
)
...
@@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window():
...
@@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window():
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
# there will be no matched prefix.
assert
len
(
computed_blocks
.
blocks
)
==
0
assert
len
(
computed_blocks
.
blocks
[
0
]
)
==
0
assert
num_tokens
==
0
assert
num_tokens
==
0
tests/v1/core/test_scheduler.py
View file @
f8a1a2d1
...
@@ -97,7 +97,7 @@ def create_scheduler(
...
@@ -97,7 +97,7 @@ def create_scheduler(
)
)
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
KVCacheGroupSpec
([
'layer'
],
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
...
@@ -814,10 +814,10 @@ def _assert_right_kv_cache_manager(
...
@@ -814,10 +814,10 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS
=
num_tokens
//
block_size
EXPECTED_TOTAL_BLOCKS
=
num_tokens
//
block_size
for
req_id
in
req_ids
:
for
req_id
in
req_ids
:
blocks
=
(
scheduler
.
kv_cache_manager
.
single_type_manage
r
.
blocks
=
(
scheduler
.
kv_cache_manager
.
coordinato
r
.
req_to_blocks
[
req_id
])
single_type_managers
[
0
].
req_to_blocks
[
req_id
])
hashes
=
scheduler
.
kv_cache_manager
.
req_to_block_hashes
[
req_id
]
hashes
=
scheduler
.
kv_cache_manager
.
req_to_block_hashes
[
req_id
]
assert
(
scheduler
.
kv_cache_manager
.
single_type_manager
.
assert
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_manager
s
[
0
]
.
num_cached_block
[
req_id
]
==
EXPECTED_TOTAL_BLOCKS
)
num_cached_block
[
req_id
]
==
EXPECTED_TOTAL_BLOCKS
)
assert
len
(
blocks
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
blocks
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
hashes
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
hashes
)
==
EXPECTED_TOTAL_BLOCKS
...
@@ -1198,11 +1198,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
...
@@ -1198,11 +1198,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert
len
(
scheduler
.
encoder_cache_manager
.
cached
)
==
0
assert
len
(
scheduler
.
encoder_cache_manager
.
cached
)
==
0
# KVCache Manager.
# KVCache Manager.
assert
len
(
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_blocks
)
==
0
req_to_blocks
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
scheduler
.
kv_cache_manager
.
single_type_manager
.
num_cached_block
)
==
0
num_cached_block
)
==
0
num_free_blocks
=
(
num_free_blocks
=
(
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
assert
num_free_blocks
==
(
assert
num_free_blocks
==
(
...
...
tests/v1/core/test_specialized_manager.py
View file @
f8a1a2d1
...
@@ -4,7 +4,8 @@
...
@@ -4,7 +4,8 @@
import
torch
import
torch
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
KVCacheBlock
)
from
vllm.v1.core.single_type_kv_cache_manager
import
SlidingWindowManager
from
vllm.v1.core.single_type_kv_cache_manager
import
SlidingWindowManager
from
vllm.v1.kv_cache_interface
import
SlidingWindowSpec
from
vllm.v1.kv_cache_interface
import
SlidingWindowSpec
...
@@ -12,9 +13,8 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec
...
@@ -12,9 +13,8 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec
def
get_sliding_window_manager
(
sliding_window_spec
,
block_pool
):
def
get_sliding_window_manager
(
sliding_window_spec
,
block_pool
):
return
SlidingWindowManager
(
sliding_window_spec
,
return
SlidingWindowManager
(
sliding_window_spec
,
block_pool
,
block_pool
,
use_eagle
=
False
,
caching_hash_fn
=
lambda
x
:
x
,
num_kv_cache_groups
=
1
,
kv_cache_group_id
=
0
)
caching_hash_fn
=
lambda
x
:
x
)
def
test_sliding_window_possible_cached_prefix
():
def
test_sliding_window_possible_cached_prefix
():
...
@@ -42,13 +42,18 @@ def test_sliding_window_possible_cached_prefix():
...
@@ -42,13 +42,18 @@ def test_sliding_window_possible_cached_prefix():
for
i
,
(
block_hash
,
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
block_hash
]
=
{
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
i
:
block_pool
.
blocks
[
i
+
10
]
block_hash
,
0
)]
=
{
}
i
:
block_pool
.
blocks
[
i
+
10
],
}
computed_blocks
=
manager
.
find_longest_cache_hit
(
computed_blocks
=
manager
.
find_longest_cache_hit
(
block_hash_list
,
block_hashes
=
block_hash_list
,
len
(
block_hash_list
)
*
block_size
)
max_length
=
len
(
block_hash_list
)
*
block_size
,
kv_cache_group_ids
=
[
0
],
block_pool
=
block_pool
,
kv_cache_spec
=
sliding_window_spec
,
use_eagle
=
False
)[
0
]
assert
len
(
computed_blocks
)
==
expect_length
assert
len
(
computed_blocks
)
==
expect_length
assert
all
(
block
==
block_pool
.
null_block
assert
all
(
block
==
block_pool
.
null_block
...
@@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks():
...
@@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks():
null_block_id
=
block_pool
.
null_block
.
block_id
null_block_id
=
block_pool
.
null_block
.
block_id
def
id_to_block_table
(
ids
):
def
id_to_block_table
(
ids
)
->
list
[
KVCacheBlock
]
:
return
[
return
[
KVCacheBlock
(
id_
)
KVCacheBlock
(
id_
)
if
id_
!=
null_block_id
else
block_pool
.
null_block
for
id_
in
ids
if
id_
!=
null_block_id
else
block_pool
.
null_block
for
id_
in
ids
]
]
def
assert_block_id
(
block_table
,
ids
):
def
assert_block_id
(
block_table
:
list
[
KVCacheBlock
],
ids
:
list
[
int
]
):
for
block
,
id_
in
zip
(
block_table
,
ids
):
for
block
,
id_
in
zip
(
block_table
,
ids
):
if
id_
==
null_block_id
:
if
id_
==
null_block_id
:
assert
block
==
block_pool
.
null_block
assert
block
==
block_pool
.
null_block
...
...
tests/v1/e2e/test_correctness_sliding_window.py
View file @
f8a1a2d1
...
@@ -18,7 +18,7 @@ class TestConfig:
...
@@ -18,7 +18,7 @@ class TestConfig:
model_config
=
{
model_config
=
{
"bigcode/starcoder2-3b"
:
TestConfig
(
4096
,
(
800
,
1100
)),
"bigcode/starcoder2-3b"
:
TestConfig
(
4096
,
(
800
,
1100
)),
"google/gemma-
2-2
b-it"
:
TestConfig
(
4096
,
(
400
,
800
)),
"google/gemma-
3-1
b-it"
:
TestConfig
(
4096
,
(
400
,
800
)),
}
}
...
@@ -26,7 +26,7 @@ model_config = {
...
@@ -26,7 +26,7 @@ model_config = {
"model"
,
"model"
,
[
[
"bigcode/starcoder2-3b"
,
# sliding window only
"bigcode/starcoder2-3b"
,
# sliding window only
"google/gemma-
2-2
b-it"
,
# sliding window + full attention
"google/gemma-
3-1
b-it"
,
# sliding window + full attention
])
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
f8a1a2d1
...
@@ -36,8 +36,8 @@ def test_basic_inferface():
...
@@ -36,8 +36,8 @@ def test_basic_inferface():
req_meta
=
kv_connector_metadata
.
requests
[
request_id
]
req_meta
=
kv_connector_metadata
.
requests
[
request_id
]
for
block_id
,
block
in
zip
(
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
scheduler
.
kv_cache_manager
.
req_meta
.
local_block_ids
,
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_manager
.
req_to_blocks
[
request_id
]):
single_type_manager
s
[
0
]
.
req_to_blocks
[
request_id
]):
assert
block_id
==
block
.
block_id
assert
block_id
==
block
.
block_id
...
...
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
f8a1a2d1
...
@@ -54,8 +54,8 @@ def test_basic_lifecycle():
...
@@ -54,8 +54,8 @@ def test_basic_lifecycle():
assert
len
(
scheduler
.
waiting
)
==
0
assert
len
(
scheduler
.
waiting
)
==
0
# ... but blocks should not be freed.
# ... but blocks should not be freed.
blocks
=
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_block
s
[
blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
request_id
]
0
].
req_to_blocks
[
request_id
]
for
block
in
blocks
:
for
block
in
blocks
:
assert
block
.
ref_cnt
==
1
assert
block
.
ref_cnt
==
1
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
f8a1a2d1
...
@@ -51,8 +51,8 @@ def test_basic_lifecycle():
...
@@ -51,8 +51,8 @@ def test_basic_lifecycle():
assert
(
block_pool
.
free_block_queue
.
num_free_blocks
assert
(
block_pool
.
free_block_queue
.
num_free_blocks
<
START_FREE_BLOCK_QUEUE_SIZE
)
<
START_FREE_BLOCK_QUEUE_SIZE
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
0
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
0
blocks
=
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_block
s
[
blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
request_id
]
0
].
req_to_blocks
[
request_id
]
for
block
in
blocks
:
for
block
in
blocks
:
assert
block
.
_block_hash
is
None
assert
block
.
_block_hash
is
None
...
@@ -87,8 +87,8 @@ def test_basic_lifecycle():
...
@@ -87,8 +87,8 @@ def test_basic_lifecycle():
# Confirm the block are actually allocated.
# Confirm the block are actually allocated.
num_hashed_blocks
=
0
num_hashed_blocks
=
0
blocks
=
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_block
s
[
blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
request_id
]
0
].
req_to_blocks
[
request_id
]
for
block
in
blocks
:
for
block
in
blocks
:
assert
block
.
ref_cnt
==
1
assert
block
.
ref_cnt
==
1
num_hashed_blocks
+=
(
1
if
block
.
_block_hash
is
not
None
else
0
)
num_hashed_blocks
+=
(
1
if
block
.
_block_hash
is
not
None
else
0
)
...
@@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching():
...
@@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching():
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
local_blocks
=
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_block
s
[
local_blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
request_local
.
request_id
]
0
].
req_to_blocks
[
request_local
.
request_id
]
remote_blocks
=
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_blocks
[
# noqa: E501
remote_blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_manager
s
[
request_remote
.
request_id
]
0
].
req_to_blocks
[
request_remote
.
request_id
]
# Local should have cached blocks (but not all due to preallocate).
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks
=
0
num_hashed_blocks
=
0
...
@@ -300,8 +300,8 @@ def test_full_block_prompt():
...
@@ -300,8 +300,8 @@ def test_full_block_prompt():
# STEP (1): Initialize a recv.
# STEP (1): Initialize a recv.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
# All blocks should be allocated.
# All blocks should be allocated.
num_blocks
=
len
(
scheduler
.
kv_cache_manager
.
single_type_manage
r
.
num_blocks
=
len
(
scheduler
.
kv_cache_manager
.
coordinato
r
.
req_to_blocks
[
request_id
])
single_type_managers
[
0
].
req_to_blocks
[
request_id
])
assert
num_blocks
==
NUM_EXTERNAL_FULL_BLOCKS
assert
num_blocks
==
NUM_EXTERNAL_FULL_BLOCKS
model_runner_output
=
EMPTY_MODEL_RUNNER_OUTPUT
model_runner_output
=
EMPTY_MODEL_RUNNER_OUTPUT
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
...
@@ -319,8 +319,8 @@ def test_full_block_prompt():
...
@@ -319,8 +319,8 @@ def test_full_block_prompt():
# We need to recompute the final token of the prompt to generate
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
# the first new token, so we should not have a new block.
num_blocks
=
len
(
scheduler
.
kv_cache_manager
.
single_type_manage
r
.
num_blocks
=
len
(
scheduler
.
kv_cache_manager
.
coordinato
r
.
req_to_blocks
[
request_id
])
single_type_managers
[
0
].
req_to_blocks
[
request_id
])
assert
num_blocks
==
NUM_EXTERNAL_FULL_BLOCKS
assert
num_blocks
==
NUM_EXTERNAL_FULL_BLOCKS
assert
(
scheduler_output
.
scheduled_new_reqs
[
0
].
num_computed_tokens
==
assert
(
scheduler_output
.
scheduled_new_reqs
[
0
].
num_computed_tokens
==
NUM_TOKENS
-
1
)
NUM_TOKENS
-
1
)
...
...
tests/v1/kv_connector/unit/utils.py
View file @
f8a1a2d1
...
@@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
...
@@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert
len
(
scheduler
.
encoder_cache_manager
.
cached
)
==
0
assert
len
(
scheduler
.
encoder_cache_manager
.
cached
)
==
0
# KVCache Manager.
# KVCache Manager.
assert
len
(
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
scheduler
.
kv_cache_manager
.
single_type_manager
.
req_to_blocks
)
==
0
req_to_blocks
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
scheduler
.
kv_cache_manager
.
single_type_manager
.
num_cached_block
)
==
0
num_cached_block
)
==
0
num_free_blocks
=
(
num_free_blocks
=
(
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
assert
num_free_blocks
==
(
assert
num_free_blocks
==
(
...
@@ -96,7 +96,7 @@ def create_scheduler(
...
@@ -96,7 +96,7 @@ def create_scheduler(
block_size
=
vllm_config
.
cache_config
.
block_size
block_size
=
vllm_config
.
cache_config
.
block_size
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
num_blocks
=
num_blocks
,
# A large number of blocks to hold all requests
tensors
=
{}
,
kv_cache_
tensors
=
[]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
([
'layer'
],
KVCacheGroupSpec
([
'layer'
],
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
f8a1a2d1
...
@@ -40,12 +40,13 @@ def initialize_kv_cache(runner: GPUModelRunner):
...
@@ -40,12 +40,13 @@ def initialize_kv_cache(runner: GPUModelRunner):
tensor_size
=
attn_spec
.
page_size_bytes
*
NUM_BLOCKS
tensor_size
=
attn_spec
.
page_size_bytes
*
NUM_BLOCKS
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
NUM_BLOCKS
,
num_blocks
=
NUM_BLOCKS
,
tensors
=
{
kv_cache_
tensors
=
[
"layer.0"
:
KVCacheTensor
(
size
=
tensor_size
),
KVCacheTensor
(
size
=
tensor_size
,
shared_by
=
[
"layer.0"
]
),
}
,
]
,
kv_cache_groups
=
[
kv_cache_groups
=
[
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
attn_spec
)
KVCacheGroupSpec
(
layer_names
=
[
"layer.0"
],
kv_cache_spec
=
attn_spec
)
])
],
)
runner
.
kv_cache_config
=
kv_cache_config
runner
.
kv_cache_config
=
kv_cache_config
runner
.
input_batch
=
InputBatch
(
runner
.
input_batch
=
InputBatch
(
max_num_reqs
=
runner
.
max_num_reqs
,
max_num_reqs
=
runner
.
max_num_reqs
,
...
@@ -518,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing():
...
@@ -518,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing():
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
2
assert
len
(
kv_cache_config
.
kv_cache_
tensors
)
==
2
assert
kv_cache_config
.
tensors
[
layer_
0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
==
available_memory
//
2
assert
kv_cache_config
.
tensors
[
layer_
1
].
size
==
available_memory
//
2
assert
kv_cache_config
.
kv_cache_
tensors
[
1
].
size
==
available_memory
//
2
max_context_len
=
\
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
...
@@ -530,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing():
...
@@ -530,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing():
# important: override tensor size to prevent large mem alloc during test
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
num_blocks
=
1
for
laye
r
in
kv_cache_config
.
tensors
:
for
kv_cache_tenso
r
in
kv_cache_config
.
kv_cache_
tensors
:
kv_cache_
config
.
tensors
[
layer
]
.
size
=
\
kv_cache_
tensor
.
size
=
(
kv_cache_spec
[
layer
].
page_size_bytes
kv_cache_spec
[
kv_cache_tensor
.
shared_by
[
0
]
].
page_size_bytes
)
runner
.
initialize_kv_cache
(
kv_cache_config
)
runner
.
initialize_kv_cache
(
kv_cache_config
)
...
@@ -589,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
...
@@ -589,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
available_memory
)
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
kv_cache_config
.
num_blocks
==
num_expected_blocks
assert
len
(
kv_cache_config
.
tensors
)
==
1
assert
len
(
kv_cache_config
.
kv_cache_
tensors
)
==
1
# Each layer now has twice the available memory for KV cache
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
# compared to no KV sharing
assert
kv_cache_config
.
tensors
[
layer_
0
].
size
==
available_memory
assert
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
==
available_memory
max_context_len
=
\
max_context_len
=
\
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
5
*
GiB_bytes
)
...
@@ -602,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
...
@@ -602,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
# important: override tensor size to prevent large mem alloc during test
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
num_blocks
=
1
kv_cache_config
.
tensors
[
layer_
0
].
size
=
\
kv_cache_config
.
kv_cache_
tensors
[
0
].
size
=
\
kv_cache_spec
[
layer_0
].
page_size_bytes
kv_cache_spec
[
layer_0
].
page_size_bytes
runner
.
initialize_kv_cache
(
kv_cache_config
)
runner
.
initialize_kv_cache
(
kv_cache_config
)
...
...
vllm/config.py
View file @
f8a1a2d1
...
@@ -2104,6 +2104,12 @@ class SchedulerConfig:
...
@@ -2104,6 +2104,12 @@ class SchedulerConfig:
default scheduler. Can be a class directly or the path to a class of form
default scheduler. Can be a class directly or the path to a class of form
"mod.custom_class"."""
"mod.custom_class"."""
disable_hybrid_kv_cache_manager
:
bool
=
False
"""If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention.
"""
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
"""
"""
WARNING: Whenever a new field is added to this config,
WARNING: Whenever a new field is added to this config,
...
@@ -4465,6 +4471,21 @@ class VllmConfig:
...
@@ -4465,6 +4471,21 @@ class VllmConfig:
if
not
self
.
instance_id
:
if
not
self
.
instance_id
:
self
.
instance_id
=
random_uuid
()[:
5
]
self
.
instance_id
=
random_uuid
()[:
5
]
if
(
envs
.
VLLM_USE_V1
and
not
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
):
# logger should only print warning message for hybrid models. As we
# can't know whether the model is hybrid or not now, so we don't log
# warning message here and will log it later.
if
not
(
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
()):
# Hybrid KV cache manager is not supported on non-GPU platforms.
self
.
disable_hybrid_kv_cache_manager
=
True
if
self
.
kv_transfer_config
is
not
None
:
# Hybrid KV cache manager is not compatible with KV transfer.
self
.
disable_hybrid_kv_cache_manager
=
True
if
self
.
kv_events_config
is
not
None
:
# Hybrid KV cache manager is not compatible with KV events.
self
.
disable_hybrid_kv_cache_manager
=
True
def
update_sizes_for_sequence_parallelism
(
self
,
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
possible_sizes
:
list
)
->
list
:
# remove the sizes that not multiple of tp_size when
# remove the sizes that not multiple of tp_size when
...
...
vllm/engine/arg_utils.py
View file @
f8a1a2d1
...
@@ -387,6 +387,9 @@ class EngineArgs:
...
@@ -387,6 +387,9 @@ class EngineArgs:
bool
]
=
SchedulerConfig
.
enable_chunked_prefill
bool
]
=
SchedulerConfig
.
enable_chunked_prefill
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_hybrid_kv_cache_manager
:
bool
=
(
SchedulerConfig
.
disable_hybrid_kv_cache_manager
)
guided_decoding_backend
:
GuidedDecodingBackend
=
DecodingConfig
.
backend
guided_decoding_backend
:
GuidedDecodingBackend
=
DecodingConfig
.
backend
guided_decoding_disable_fallback
:
bool
=
DecodingConfig
.
disable_fallback
guided_decoding_disable_fallback
:
bool
=
DecodingConfig
.
disable_fallback
guided_decoding_disable_any_whitespace
:
bool
=
\
guided_decoding_disable_any_whitespace
:
bool
=
\
...
@@ -849,6 +852,9 @@ class EngineArgs:
...
@@ -849,6 +852,9 @@ class EngineArgs:
**
scheduler_kwargs
[
"disable_chunked_mm_input"
])
**
scheduler_kwargs
[
"disable_chunked_mm_input"
])
scheduler_group
.
add_argument
(
"--scheduler-cls"
,
scheduler_group
.
add_argument
(
"--scheduler-cls"
,
**
scheduler_kwargs
[
"scheduler_cls"
])
**
scheduler_kwargs
[
"scheduler_cls"
])
scheduler_group
.
add_argument
(
"--disable-hybrid-kv-cache-manager"
,
**
scheduler_kwargs
[
"disable_hybrid_kv_cache_manager"
])
# vLLM arguments
# vLLM arguments
vllm_kwargs
=
get_kwargs
(
VllmConfig
)
vllm_kwargs
=
get_kwargs
(
VllmConfig
)
...
@@ -1174,6 +1180,8 @@ class EngineArgs:
...
@@ -1174,6 +1180,8 @@ class EngineArgs:
max_num_partial_prefills
=
self
.
max_num_partial_prefills
,
max_num_partial_prefills
=
self
.
max_num_partial_prefills
,
max_long_partial_prefills
=
self
.
max_long_partial_prefills
,
max_long_partial_prefills
=
self
.
max_long_partial_prefills
,
long_prefill_token_threshold
=
self
.
long_prefill_token_threshold
,
long_prefill_token_threshold
=
self
.
long_prefill_token_threshold
,
disable_hybrid_kv_cache_manager
=
self
.
disable_hybrid_kv_cache_manager
,
)
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
...
...
vllm/v1/core/block_pool.py
View file @
f8a1a2d1
...
@@ -7,8 +7,8 @@ from typing import Callable, Optional
...
@@ -7,8 +7,8 @@ from typing import Callable, Optional
from
vllm.distributed.kv_events
import
(
AllBlocksCleared
,
BlockRemoved
,
from
vllm.distributed.kv_events
import
(
AllBlocksCleared
,
BlockRemoved
,
BlockStored
,
KVCacheEvent
)
BlockStored
,
KVCacheEvent
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
FreeKVCacheBlockQueue
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
KVCacheBlock
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
generate_block_hash_extra_keys
,
generate_block_hash_extra_keys
,
hash_block_tokens
)
hash_block_tokens
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -27,6 +27,7 @@ class BlockPool:
...
@@ -27,6 +27,7 @@ class BlockPool:
Args:
Args:
num_gpu_blocks: The number of blocks in the pool.
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
enable_caching: Whether to enable prefix caching.
enable_kv_cache_events: Whether to enable kv cache events.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -56,7 +57,7 @@ class BlockPool:
...
@@ -56,7 +57,7 @@ class BlockPool:
# if there is already an identical block in the cache. This is because
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
# block tables are append-only.
self
.
cached_block_hash_to_block
:
dict
[
BlockHash
,
dict
[
self
.
cached_block_hash_to_block
:
dict
[
BlockHash
WithGroupId
,
dict
[
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
# To represent a placeholder block with block_id=0.
# To represent a placeholder block with block_id=0.
...
@@ -68,22 +69,29 @@ class BlockPool:
...
@@ -68,22 +69,29 @@ class BlockPool:
self
.
enable_kv_cache_events
=
enable_kv_cache_events
self
.
enable_kv_cache_events
=
enable_kv_cache_events
self
.
kv_event_queue
:
list
[
KVCacheEvent
]
=
[]
self
.
kv_event_queue
:
list
[
KVCacheEvent
]
=
[]
def
get_cached_block
(
self
,
def
get_cached_block
(
block_hash
:
BlockHash
)
->
Optional
[
KVCacheBlock
]:
self
,
block_hash
:
BlockHash
,
"""Get a cached block by the block hash, or None if cache miss.
kv_cache_group_ids
:
list
[
int
])
->
Optional
[
list
[
KVCacheBlock
]]:
"""Get the cached block by the block hash for each group in
`kv_cache_group_ids`, or None if cache miss for any group.
If there are duplicated blocks, we return the first block in the cache.
If there are duplicated blocks, we return the first block in the cache.
Args:
Args:
block_hash: The hash value of the block.
block_hash: The hash value of the block.
kv_cache_group_ids: The ids of the KV cache groups.
Returns:
Returns:
The cached block if
it
exists, or None.
The cached block
s
if exists, or None.
"""
"""
cached_blocks
=
self
.
cached_block_hash_to_block
.
get
(
block_hash
)
cached_blocks
=
[]
if
not
cached_blocks
:
for
group_id
in
kv_cache_group_ids
:
return
None
cached_blocks_one_group
=
self
.
cached_block_hash_to_block
.
get
(
first_block_id
=
next
(
iter
(
cached_blocks
))
BlockHashWithGroupId
(
block_hash
,
group_id
))
return
cached_blocks
[
first_block_id
]
if
not
cached_blocks_one_group
:
return
None
first_block_id
=
next
(
iter
(
cached_blocks_one_group
))
cached_blocks
.
append
(
cached_blocks_one_group
[
first_block_id
])
return
cached_blocks
def
cache_full_blocks
(
def
cache_full_blocks
(
self
,
self
,
...
@@ -93,6 +101,7 @@ class BlockPool:
...
@@ -93,6 +101,7 @@ class BlockPool:
num_cached_blocks
:
int
,
num_cached_blocks
:
int
,
num_full_blocks
:
int
,
num_full_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
kv_cache_group_id
:
int
,
hash_fn
:
Callable
,
hash_fn
:
Callable
,
)
->
None
:
)
->
None
:
"""Cache a list of full blocks for prefix caching.
"""Cache a list of full blocks for prefix caching.
...
@@ -112,6 +121,7 @@ class BlockPool:
...
@@ -112,6 +121,7 @@ class BlockPool:
num_full_blocks: The number of blocks that are full and should
num_full_blocks: The number of blocks that are full and should
be cached after this function.
be cached after this function.
block_size: Number of tokens in each block.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
hash_fn: The hash function to use for block hashes.
"""
"""
if
num_cached_blocks
==
num_full_blocks
:
if
num_cached_blocks
==
num_full_blocks
:
...
@@ -126,7 +136,7 @@ class BlockPool:
...
@@ -126,7 +136,7 @@ class BlockPool:
else
:
else
:
prev_block
=
blocks
[
num_cached_blocks
-
1
]
prev_block
=
blocks
[
num_cached_blocks
-
1
]
assert
prev_block
.
block_hash
is
not
None
assert
prev_block
.
block_hash
is
not
None
prev_block_hash_value
=
prev_block
.
block_hash
.
hash_value
prev_block_hash_value
=
prev_block
.
block_hash
.
get_
hash_value
()
parent_block_hash
=
prev_block_hash_value
parent_block_hash
=
prev_block_hash_value
new_hashes
:
Optional
[
list
[
int
]]
=
([]
if
self
.
enable_kv_cache_events
new_hashes
:
Optional
[
list
[
int
]]
=
([]
if
self
.
enable_kv_cache_events
...
@@ -138,8 +148,9 @@ class BlockPool:
...
@@ -138,8 +148,9 @@ class BlockPool:
# The block hash may already be computed in
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# generated tokens with preemption), or by other
# reuse the block hash.
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash
=
new_block_hashes
[
i
]
block_hash
=
new_block_hashes
[
i
]
else
:
else
:
# Otherwise compute the block hash and cache it in the request
# Otherwise compute the block hash and cache it in the request
...
@@ -166,8 +177,11 @@ class BlockPool:
...
@@ -166,8 +177,11 @@ class BlockPool:
block_hashes
.
append
(
block_hash
)
block_hashes
.
append
(
block_hash
)
# Update and added the full block to the cache.
# Update and added the full block to the cache.
blk
.
block_hash
=
block_hash
block_hash_with_group_id
=
BlockHashWithGroupId
(
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
block_hash
,
kv_cache_group_id
)
blk
.
block_hash
=
block_hash_with_group_id
self
.
cached_block_hash_to_block
[
block_hash_with_group_id
][
blk
.
block_id
]
=
blk
if
new_hashes
is
not
None
:
if
new_hashes
is
not
None
:
new_hashes
.
append
(
block_hash
.
hash_value
)
new_hashes
.
append
(
block_hash
.
hash_value
)
prev_block_hash_value
=
block_hash
.
hash_value
prev_block_hash_value
=
block_hash
.
hash_value
...
@@ -237,12 +251,16 @@ class BlockPool:
...
@@ -237,12 +251,16 @@ class BlockPool:
del
self
.
cached_block_hash_to_block
[
block_hash
]
del
self
.
cached_block_hash_to_block
[
block_hash
]
if
self
.
enable_kv_cache_events
:
if
self
.
enable_kv_cache_events
:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self
.
kv_event_queue
.
append
(
self
.
kv_event_queue
.
append
(
BlockRemoved
(
block_hashes
=
[
block_hash
.
hash_value
]))
BlockRemoved
(
block_hashes
=
[
block_hash
.
get_
hash_value
()
]))
return
True
return
True
return
False
return
False
def
touch
(
self
,
blocks
:
list
[
KVCacheBlock
])
->
None
:
def
touch
(
self
,
blocks
:
list
[
list
[
KVCacheBlock
]
]
)
->
None
:
"""Touch a block increases its reference count by 1, and may remove
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
another request with the same prefix.
...
@@ -250,12 +268,13 @@ class BlockPool:
...
@@ -250,12 +268,13 @@ class BlockPool:
Args:
Args:
blocks: A list of blocks to touch.
blocks: A list of blocks to touch.
"""
"""
for
block
in
blocks
:
for
blocks_per_group
in
blocks
:
# ref_cnt=0 means this block is in the free list (i.e. eviction
for
block
in
blocks_per_group
:
# candidate), so remove it.
# ref_cnt=0 means this block is in the free list (i.e. eviction
if
block
.
ref_cnt
==
0
and
not
block
.
is_null
:
# candidate), so remove it.
self
.
free_block_queue
.
remove
(
block
)
if
block
.
ref_cnt
==
0
and
not
block
.
is_null
:
block
.
incr_ref
()
self
.
free_block_queue
.
remove
(
block
)
block
.
incr_ref
()
def
free_blocks
(
self
,
ordered_blocks
:
Iterable
[
KVCacheBlock
])
->
None
:
def
free_blocks
(
self
,
ordered_blocks
:
Iterable
[
KVCacheBlock
])
->
None
:
"""Free a list of blocks. The blocks should be ordered by their
"""Free a list of blocks. The blocks should be ordered by their
...
...
vllm/v1/core/kv_cache_coordinator.py
0 → 100644
View file @
f8a1a2d1
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
SingleTypeKVCacheManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.request
import
Request
class
KVCacheCoordinator
(
ABC
):
"""
Coordinate the KV cache of different KV cache groups.
"""
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
,
):
self
.
kv_cache_config
=
kv_cache_config
self
.
max_model_len
=
max_model_len
self
.
block_pool
=
BlockPool
(
kv_cache_config
.
num_blocks
,
enable_caching
,
enable_kv_cache_events
)
self
.
single_type_managers
:
list
[
SingleTypeKVCacheManager
]
=
[]
# Needs special handling for find_longest_cache_hit if eagle is enabled
self
.
use_eagle
=
use_eagle
for
i
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
)):
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
i
].
kv_cache_spec
self
.
single_type_managers
.
append
(
get_manager_for_kv_cache_spec
(
kv_cache_spec
=
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
))
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
int
:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_blocks_to_allocate
=
0
for
i
,
manager
in
enumerate
(
self
.
single_type_managers
):
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
request_id
,
num_tokens
,
new_computed_blocks
[
i
])
return
num_blocks_to_allocate
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
new_computed_blocks
:
list
[
list
[
KVCacheBlock
]])
->
None
:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for
i
,
manager
in
enumerate
(
self
.
single_type_managers
):
manager
.
save_new_computed_blocks
(
request_id
,
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
)
->
list
[
list
[
KVCacheBlock
]]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
new_blocks
=
[]
for
manager
in
self
.
single_type_managers
:
new_blocks
.
append
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
))
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for
manager
in
self
.
single_type_managers
:
manager
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
def
free
(
self
,
request_id
:
str
)
->
None
:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for
manager
in
self
.
single_type_managers
:
manager
.
free
(
request_id
)
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
num_running_requests
:
int
)
->
list
[
int
]:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
block_hashes: The block hashes of the request.
Returns:
The number of common prefix blocks.
"""
num_blocks_per_group
=
[
manager
.
get_num_common_prefix_blocks
(
request_id
,
num_running_requests
)
for
manager
in
self
.
single_type_managers
]
return
num_blocks_per_group
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
num_computed_tokens
:
int
)
->
None
:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
for
manager
in
self
.
single_type_managers
:
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
def
get_blocks
(
self
,
request_id
:
str
)
->
list
[
list
[
KVCacheBlock
]]:
"""
Get the blocks for the request.
"""
return
[
manager
.
req_to_blocks
[
request_id
]
for
manager
in
self
.
single_type_managers
]
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
pass
class
UnitaryKVCacheCoordinator
(
KVCacheCoordinator
):
"""
KV cache coordinator for models with only one KV cache group. This is the
case for models with only one KV cache type, e.g., all attention layers use
full attention or all attention layers use sliding window attention.
"""
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
):
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
self
.
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
self
.
block_size
=
self
.
kv_cache_spec
.
block_size
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
hit_blocks
=
self
.
single_type_managers
[
0
].
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
max_length
=
max_cache_hit_length
,
kv_cache_group_ids
=
[
0
],
block_pool
=
self
.
block_pool
,
kv_cache_spec
=
self
.
kv_cache_spec
,
use_eagle
=
self
.
use_eagle
,
)
return
hit_blocks
,
len
(
hit_blocks
[
0
])
*
self
.
block_size
class
HybridKVCacheCoordinator
(
KVCacheCoordinator
):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, and one of them must be full attention.
May extend to more general cases in the future.
"""
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
):
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
self
.
verify_and_split_kv_cache_groups
()
def
verify_and_split_kv_cache_groups
(
self
)
->
None
:
"""
Verifies that the model has exactly two types of KV cache groups, and
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
"""
full_attention_type_id
:
Optional
[
str
]
=
None
other_type_id
:
Optional
[
str
]
=
None
self
.
full_attention_group_ids
:
list
[
int
]
=
[]
self
.
other_group_ids
:
list
[
int
]
=
[]
for
i
,
g
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
if
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
):
if
full_attention_type_id
is
None
:
full_attention_type_id
=
g
.
kv_cache_spec
.
type_id
else
:
assert
full_attention_type_id
==
g
.
kv_cache_spec
.
type_id
,
(
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now."
)
self
.
full_attention_group_ids
.
append
(
i
)
else
:
if
other_type_id
is
None
:
other_type_id
=
g
.
kv_cache_spec
.
type_id
else
:
assert
other_type_id
==
g
.
kv_cache_spec
.
type_id
,
(
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now."
)
self
.
other_group_ids
.
append
(
i
)
assert
full_attention_type_id
is
not
None
,
(
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now."
)
assert
other_type_id
is
not
None
,
(
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now."
)
self
.
full_attention_manager_cls
=
FullAttentionManager
self
.
other_attention_cls
=
self
.
single_type_managers
[
self
.
other_group_ids
[
0
]].
__class__
self
.
full_attention_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
self
.
full_attention_group_ids
[
0
]].
kv_cache_spec
self
.
other_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
self
.
other_group_ids
[
0
]].
kv_cache_spec
self
.
full_attention_block_size
=
self
.
full_attention_spec
.
block_size
self
.
other_block_size
=
self
.
other_spec
.
block_size
assert
self
.
other_block_size
%
self
.
full_attention_block_size
==
0
,
(
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now."
)
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_cache_hit_length
:
int
,
)
->
tuple
[
list
[
list
[
KVCacheBlock
]],
int
]:
"""
Find the longest cache hit for the request.
Args:
block_hashes: The block hashes of the request.
max_cache_hit_length: The maximum length of the cache hit.
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
hit_blocks_full_attn
=
(
self
.
full_attention_manager_cls
.
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
max_length
=
max_cache_hit_length
,
kv_cache_group_ids
=
self
.
full_attention_group_ids
,
block_pool
=
self
.
block_pool
,
kv_cache_spec
=
self
.
full_attention_spec
,
use_eagle
=
self
.
use_eagle
,
))
hit_length
=
len
(
hit_blocks_full_attn
[
0
])
*
self
.
full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
hit_blocks_other_attn
=
(
self
.
other_attention_cls
.
find_longest_cache_hit
(
block_hashes
=
block_hashes
,
max_length
=
hit_length
,
kv_cache_group_ids
=
self
.
other_group_ids
,
block_pool
=
self
.
block_pool
,
kv_cache_spec
=
self
.
other_spec
,
use_eagle
=
self
.
use_eagle
,
))
hit_length
=
len
(
hit_blocks_other_attn
[
0
])
*
self
.
other_block_size
# NOTE: the prefix cache hit length must be a multiply of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiply of the block size of
# full attention layers in current implementation, because hit_length is
# a multiply of other attention's block size, and other attention's
# block size is a multiply of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert
hit_length
%
self
.
full_attention_block_size
==
0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for
i
in
range
(
len
(
hit_blocks_full_attn
)):
del
hit_blocks_full_attn
[
i
][
hit_length
//
self
.
full_attention_block_size
:]
# Merge the hit blocks of full attention and other attention.
hit_blocks
=
hit_blocks_other_attn
for
group_id
,
blocks
in
enumerate
(
hit_blocks_full_attn
):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks
.
insert
(
group_id
,
blocks
)
return
hit_blocks
,
hit_length
def
get_kv_cache_coordinator
(
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
)
->
KVCacheCoordinator
:
if
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
:
return
UnitaryKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
else
:
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
vllm/v1/core/kv_cache_manager.py
View file @
f8a1a2d1
...
@@ -8,11 +8,9 @@ from typing import Optional
...
@@ -8,11 +8,9 @@ from typing import Optional
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
from
vllm.utils
import
sha256
from
vllm.v1.core.
block_pool
import
BlockPool
from
vllm.v1.core.
kv_cache_coordinator
import
get_kv_cache_coordinator
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
hash_request_tokens
)
hash_request_tokens
)
from
vllm.v1.core.single_type_kv_cache_manager
import
(
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
...
@@ -22,16 +20,24 @@ logger = init_logger(__name__)
...
@@ -22,16 +20,24 @@ logger = init_logger(__name__)
@
dataclass
@
dataclass
class
KVCacheBlocks
:
class
KVCacheBlocks
:
blocks
:
list
[
KVCacheBlock
]
"""
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks
:
list
[
list
[
KVCacheBlock
]]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
kv_cache_groups have the same number of blocks, which is true for now but
will be broken if we want to give different block_size to different
kv_cache_groups in the future.
"""
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
def
__add__
(
self
,
other
:
"KVCacheBlocks"
)
->
"KVCacheBlocks"
:
"""Adds two KVCacheBlocks instances."""
"""Adds two KVCacheBlocks instances."""
return
KVCacheBlocks
(
self
.
blocks
+
other
.
blocks
)
return
KVCacheBlocks
(
[
blk1
+
blk2
for
blk1
,
blk2
in
zip
(
self
.
blocks
,
other
.
blocks
)])
@
classmethod
def
create_empty
(
cls
)
->
"KVCacheBlocks"
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
cls
([])
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
)
->
list
[
list
[
int
]]:
"""
"""
...
@@ -39,15 +45,20 @@ class KVCacheBlocks:
...
@@ -39,15 +45,20 @@ class KVCacheBlocks:
Returns:
Returns:
list[list[int]]: A two-level list where
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups
(only 1 group now)
* the outer list corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
* each inner list contains the block_ids of the blocks in that group
"""
"""
return
[[
block
.
block_id
for
block
in
self
.
blocks
]]
block_ids
=
[]
for
group
in
self
.
blocks
:
block_ids
.
append
([
blk
.
block_id
for
blk
in
group
])
return
block_ids
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
def
get_unhashed_block_ids
(
self
)
->
list
[
int
]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
assert
len
(
self
.
blocks
)
==
1
,
"Only one group is supported"
return
[
return
[
block
.
block_id
for
block
in
self
.
blocks
if
block
.
block_hash
is
None
block
.
block_id
for
block
in
self
.
blocks
[
0
]
if
block
.
block_hash
is
None
]
]
...
@@ -63,12 +74,6 @@ class KVCacheManager:
...
@@ -63,12 +74,6 @@ class KVCacheManager:
log_stats
:
bool
=
False
,
log_stats
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
)
->
None
:
)
->
None
:
assert
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group"
)
kv_cache_spec
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
num_gpu_blocks
=
kv_cache_config
.
num_blocks
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
enable_caching
=
enable_caching
self
.
enable_caching
=
enable_caching
...
@@ -77,17 +82,24 @@ class KVCacheManager:
...
@@ -77,17 +82,24 @@ class KVCacheManager:
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
# FIXME: make prefix cache stats conditional on log_stats
# FIXME: make prefix cache stats conditional on log_stats
self
.
prefix_cache_stats
=
PrefixCacheStats
()
if
log_stats
else
None
self
.
prefix_cache_stats
=
PrefixCacheStats
()
if
log_stats
else
None
assert
len
(
self
.
block_pool
=
BlockPool
(
self
.
num_gpu_blocks
,
enable_caching
,
set
(
g
.
kv_cache_spec
.
block_size
enable_kv_cache_events
)
for
g
in
kv_cache_config
.
kv_cache_groups
)
)
==
1
,
"Only one block size is supported for now"
self
.
single_type_manager
=
get_manager_for_kv_cache_spec
(
self
.
block_size
=
kv_cache_config
.
kv_cache_groups
[
kv_cache_spec
=
kv_cache_spec
,
0
].
kv_cache_spec
.
block_size
block_pool
=
self
.
block_pool
,
self
.
coordinator
=
get_kv_cache_coordinator
(
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
use_eagle
=
self
.
use_eagle
,
use_eagle
=
self
.
use_eagle
,
num_kv_cache_groups
=
1
,
enable_caching
=
enable_caching
,
caching_hash_fn
=
self
.
caching_hash_fn
,
caching_hash_fn
=
self
.
caching_hash_fn
,
enable_kv_cache_events
=
enable_kv_cache_events
,
)
)
self
.
num_kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
self
.
block_pool
=
self
.
coordinator
.
block_pool
self
.
kv_cache_config
=
kv_cache_config
# Mapping from request ID to kv block hashes.
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# This is to avoid recomputing the block hashes for each call of
...
@@ -133,7 +145,7 @@ class KVCacheManager:
...
@@ -133,7 +145,7 @@ class KVCacheManager:
# When the request requires prompt logprobs, we skip prefix caching.
# When the request requires prompt logprobs, we skip prefix caching.
if
(
not
self
.
enable_caching
if
(
not
self
.
enable_caching
or
request
.
sampling_params
.
prompt_logprobs
is
not
None
):
or
request
.
sampling_params
.
prompt_logprobs
is
not
None
):
return
KVCacheBlocks
.
create_empty
(),
0
return
self
.
create_empty
_block_list
(),
0
# The block hashes for the request may already be computed
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
# if the scheduler has tried to schedule the request before.
...
@@ -154,20 +166,16 @@ class KVCacheManager:
...
@@ -154,20 +166,16 @@ class KVCacheManager:
# num_computed_tokens to be block-size aligned. Removing this limitation
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
# could slightly improve performance in the future.
max_cache_hit_length
=
request
.
num_tokens
-
1
max_cache_hit_length
=
request
.
num_tokens
-
1
computed_blocks
,
num_new_computed_tokens
=
(
computed_blocks
=
self
.
single_type_manager
.
find_longest_cache_hit
(
self
.
coordinator
.
find_longest_cache_hit
(
block_hashes
,
block_hashes
,
max_cache_hit_length
)
max_cache_hit_length
))
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
if
self
.
log_stats
:
if
self
.
log_stats
:
assert
self
.
prefix_cache_stats
is
not
None
assert
self
.
prefix_cache_stats
is
not
None
self
.
prefix_cache_stats
.
queries
+=
request
.
num_tokens
self
.
prefix_cache_stats
.
queries
+=
request
.
num_tokens
self
.
prefix_cache_stats
.
hits
+=
num_computed_tokens
self
.
prefix_cache_stats
.
hits
+=
num_
new_
computed_tokens
return
KVCacheBlocks
(
computed_blocks
),
num_computed_tokens
return
KVCacheBlocks
(
computed_blocks
),
num_
new_
computed_tokens
def
allocate_slots
(
def
allocate_slots
(
self
,
self
,
...
@@ -220,7 +228,9 @@ class KVCacheManager:
...
@@ -220,7 +228,9 @@ class KVCacheManager:
if
new_computed_blocks
is
not
None
:
if
new_computed_blocks
is
not
None
:
new_computed_block_list
=
new_computed_blocks
.
blocks
new_computed_block_list
=
new_computed_blocks
.
blocks
else
:
else
:
new_computed_block_list
=
[]
new_computed_block_list
=
[
[]
for
_
in
range
(
len
(
self
.
kv_cache_config
.
kv_cache_groups
))
]
# Free the blocks that are skipped during the attention computation
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# (e.g., tokens outside the sliding window).
...
@@ -228,8 +238,8 @@ class KVCacheManager:
...
@@ -228,8 +238,8 @@ class KVCacheManager:
# insufficient free blocks.
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
# the number of evicted blocks.
self
.
single_type_manage
r
.
remove_skipped_blocks
(
self
.
coordinato
r
.
remove_skipped_blocks
(
request
.
request_id
,
request
.
request_id
,
request
.
num_computed_tokens
)
request
.
num_computed_tokens
)
# The number of computed tokens is the number of computed tokens plus
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
# the new prefix caching hits
...
@@ -238,12 +248,12 @@ class KVCacheManager:
...
@@ -238,12 +248,12 @@ class KVCacheManager:
num_tokens_need_slot
=
min
(
num_tokens_need_slot
=
min
(
num_computed_tokens
+
num_new_tokens
+
num_lookahead_tokens
,
num_computed_tokens
+
num_new_tokens
+
num_lookahead_tokens
,
self
.
max_model_len
)
self
.
max_model_len
)
num_blocks_to_allocate
=
(
self
.
single_type_manage
r
.
get_num_blocks_to_allocate
(
num_blocks_to_allocate
=
self
.
coordinato
r
.
get_num_blocks_to_allocate
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
num_tokens
=
num_tokens_need_slot
,
num_tokens
=
num_tokens_need_slot
,
new_computed_blocks
=
new_computed_block_list
,
new_computed_blocks
=
new_computed_block_list
,
)
)
)
if
num_blocks_to_allocate
>
self
.
block_pool
.
get_num_free_blocks
():
if
num_blocks_to_allocate
>
self
.
block_pool
.
get_num_free_blocks
():
# Cannot allocate new blocks
# Cannot allocate new blocks
...
@@ -253,16 +263,16 @@ class KVCacheManager:
...
@@ -253,16 +263,16 @@ class KVCacheManager:
if
self
.
enable_caching
:
if
self
.
enable_caching
:
self
.
block_pool
.
touch
(
new_computed_block_list
)
self
.
block_pool
.
touch
(
new_computed_block_list
)
else
:
else
:
assert
not
new_computed_block_list
,
(
assert
all
(
not
blocks
for
blocks
in
new_computed_block_list
)
,
(
"Computed blocks should be empty when "
"Computed blocks should be empty when "
"prefix caching is disabled"
)
"prefix caching is disabled"
)
# Append the new computed blocks to the request blocks until now to
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
# avoid the case where the new blocks cannot be allocated.
self
.
single_type_manage
r
.
save_new_computed_blocks
(
self
.
coordinato
r
.
save_new_computed_blocks
(
request
.
request_id
,
request
.
request_id
,
new_computed_block_list
)
new_computed_block_list
)
new_blocks
=
self
.
single_type_manage
r
.
allocate_new_blocks
(
new_blocks
=
self
.
coordinato
r
.
allocate_new_blocks
(
request
.
request_id
,
num_tokens_need_slot
)
request
.
request_id
,
num_tokens_need_slot
)
# P/D: delay caching blocks if we have to recv from
# P/D: delay caching blocks if we have to recv from
...
@@ -273,7 +283,7 @@ class KVCacheManager:
...
@@ -273,7 +283,7 @@ class KVCacheManager:
# Speculated tokens might be rejected in the future, so we does
# Speculated tokens might be rejected in the future, so we does
# not cache any speculated tokens. We only cache blocks with
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
# generated (accepted) tokens.
self
.
single_type_manage
r
.
cache_blocks
(
self
.
coordinato
r
.
cache_blocks
(
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
+
num_new_tokens
-
num_draft_tokens
)
num_computed_tokens
+
num_new_tokens
-
num_draft_tokens
)
...
@@ -287,7 +297,7 @@ class KVCacheManager:
...
@@ -287,7 +297,7 @@ class KVCacheManager:
Args:
Args:
request: The request to free the blocks.
request: The request to free the blocks.
"""
"""
self
.
single_type_manage
r
.
free
(
request
.
request_id
)
self
.
coordinato
r
.
free
(
request
.
request_id
)
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
"""Reset prefix cache. This function may be used in RLHF
...
@@ -345,10 +355,8 @@ class KVCacheManager:
...
@@ -345,10 +355,8 @@ class KVCacheManager:
group.
group.
"""
"""
assert
request
.
status
==
RequestStatus
.
RUNNING
assert
request
.
status
==
RequestStatus
.
RUNNING
return
[
return
self
.
coordinator
.
get_num_common_prefix_blocks
(
self
.
single_type_manager
.
get_num_common_prefix_blocks
(
request
.
request_id
,
num_running_requests
)
request
.
request_id
,
num_running_requests
)
]
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
"""Discard the block hashes for the request.
"""Discard the block hashes for the request.
...
@@ -368,6 +376,15 @@ class KVCacheManager:
...
@@ -368,6 +376,15 @@ class KVCacheManager:
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]]:
def
get_block_ids
(
self
,
request_id
:
str
)
->
list
[
list
[
int
]]:
"""Get the block ids of a request."""
"""Get the block ids of a request."""
assert
request_id
in
self
.
single_type_manager
.
req_to_blocks
return
KVCacheBlocks
(
return
KVCacheBlocks
(
self
.
single_type_manager
.
req_to_blocks
[
request_id
]
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
).
get_block_ids
()
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request."""
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
"""Creates a new KVCacheBlocks instance with no blocks."""
return
KVCacheBlocks
([[]
for
_
in
range
(
self
.
num_kv_cache_groups
)])
vllm/v1/core/kv_cache_utils.py
View file @
f8a1a2d1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
"""KV-Cache Utilities."""
import
os
import
os
from
collections
import
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Iterable
,
Sequence
from
collections.abc
import
Iterable
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Callable
,
NamedTuple
,
Optional
from
typing
import
Any
,
Callable
,
NamedTuple
,
Optional
...
@@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
...
@@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
extra_keys
:
Optional
[
Any
]
=
None
extra_keys
:
Optional
[
Any
]
=
None
class
BlockHashWithGroupId
(
NamedTuple
):
# The hash value for the contents (e.g., token_ids) of a block without group
# ID. The value is the same for blocks representing the same tokens but for
# different groups.
block_hash
:
BlockHash
# The KV cache group ID.
group_id
:
int
def
get_hash_value
(
self
)
->
int
:
return
self
.
block_hash
.
hash_value
# The hash seed for the first block of the prefix block sequence.
# The hash seed for the first block of the prefix block sequence.
#
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# Even if the hash function is the builtin hash(), we use sha256 to generate
...
@@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
...
@@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
# This aligns with the behavior of Python's hash() function, which also uses
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH
=
int
.
from_bytes
(
os
.
urandom
(
32
),
byteorder
=
"big"
)
if
os
.
getenv
(
NONE_HASH
=
int
.
from_bytes
(
os
.
urandom
(
32
),
byteorder
=
"big"
)
if
os
.
getenv
(
'
PYTHONHASHSEED
'
)
is
None
else
sha256
(
os
.
getenv
(
'
PYTHONHASHSEED
'
))
"
PYTHONHASHSEED
"
)
is
None
else
sha256
(
os
.
getenv
(
"
PYTHONHASHSEED
"
))
class
PrefixCachingMetrics
:
class
PrefixCachingMetrics
:
...
@@ -118,7 +131,7 @@ class KVCacheBlock:
...
@@ -118,7 +131,7 @@ class KVCacheBlock:
ref_cnt
:
int
=
0
ref_cnt
:
int
=
0
# The hash of the block composed of (block hash, tuple of token IDs).
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
# It is only available when the block is full.
_block_hash
:
Optional
[
BlockHash
]
=
None
_block_hash
:
Optional
[
BlockHash
WithGroupId
]
=
None
# Used to construct a doubly linked list for free blocks.
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
...
@@ -135,11 +148,11 @@ class KVCacheBlock:
...
@@ -135,11 +148,11 @@ class KVCacheBlock:
self
.
ref_cnt
-=
1
self
.
ref_cnt
-=
1
@
property
@
property
def
block_hash
(
self
)
->
Optional
[
BlockHash
]:
def
block_hash
(
self
)
->
Optional
[
BlockHash
WithGroupId
]:
return
self
.
_block_hash
return
self
.
_block_hash
@
block_hash
.
setter
@
block_hash
.
setter
def
block_hash
(
self
,
block_hash
:
BlockHash
):
def
block_hash
(
self
,
block_hash
:
BlockHash
WithGroupId
):
assert
self
.
block_hash
is
None
,
(
assert
self
.
block_hash
is
None
,
(
"The block already has a hash. This should not happen."
)
"The block already has a hash. This should not happen."
)
self
.
_block_hash
=
block_hash
self
.
_block_hash
=
block_hash
...
@@ -151,10 +164,10 @@ class KVCacheBlock:
...
@@ -151,10 +164,10 @@ class KVCacheBlock:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
# on KVCacheBlock object recursively.
# on KVCacheBlock object recursively.
prev_block_id
=
self
.
prev_free_block
.
block_id
\
prev_block_id
=
(
self
.
prev_free_block
.
block_id
if
self
.
prev_free_block
else
None
if
self
.
prev_free_block
else
None
)
next_block_id
=
self
.
next_free_block
.
block_id
\
next_block_id
=
(
self
.
next_free_block
.
block_id
if
self
.
next_free_block
else
None
if
self
.
next_free_block
else
None
)
return
(
f
"KVCacheBlock(block_id=
{
self
.
block_id
}
, "
return
(
f
"KVCacheBlock(block_id=
{
self
.
block_id
}
, "
f
"ref_cnt=
{
self
.
ref_cnt
}
, "
f
"ref_cnt=
{
self
.
ref_cnt
}
, "
f
"_block_hash=
{
self
.
_block_hash
}
, "
f
"_block_hash=
{
self
.
_block_hash
}
, "
...
@@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
...
@@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
grouped_layer_names
:
list
[
list
[
str
]])
->
list
[
KVCacheGroupSpec
]:
grouped_layer_names
:
list
[
list
[
str
]])
->
list
[
KVCacheGroupSpec
]:
"""
"""
Create KVCacheGroupSpec object for each kv cache group layer.
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
The layers in the same group should share the same
KVCacheSpec.
KVCacheSpec.
Args:
Args:
kv_cache_spec:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
names that belong to the same group and should share the same
KVCacheSpec.
KVCacheSpec.
Returns:
Returns:
A list of KVCacheGroupSpec objects, one for each group.
A list of KVCacheGroupSpec objects, one for each group.
"""
"""
kv_cache_groups
=
[]
kv_cache_groups
=
[]
for
layer_names_one_group
in
grouped_layer_names
:
for
layer_names_one_group
in
grouped_layer_names
:
layer_specs
=
[
layer_specs
=
[
...
@@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
...
@@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
return
max_concurrency
return
max_concurrency
def
get_num_blocks
(
vllm_config
:
VllmConfig
,
num_layers
:
int
,
available_memory
:
int
,
page_size
:
int
)
->
int
:
"""
Get the number of kv cache blocks.
Args:
vllm_config: The global VllmConfig
num_layers: The number of layers
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
"""
num_blocks
=
int
(
available_memory
//
page_size
//
num_layers
)
num_blocks
=
max
(
num_blocks
,
0
)
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
num_gpu_blocks_override
=
\
vllm_config
.
cache_config
.
num_gpu_blocks_override
logger
.
info
(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d"
,
num_blocks
,
num_gpu_blocks_override
)
return
num_blocks
def
get_uniform_page_size
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
int
:
"""
Get the page size of the KV cache.
"""
page_sizes
=
set
(
layer
.
page_size_bytes
for
layer
in
kv_cache_spec
.
values
())
assert
len
(
page_sizes
)
==
1
return
page_sizes
.
pop
()
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
available_memory
:
int
)
->
KVCacheConfig
:
...
@@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
...
@@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
The generated KVCacheConfig
The generated KVCacheConfig
"""
"""
page_sizes
=
{
layer
.
page_size_bytes
for
layer
in
kv_cache_spec
.
values
()}
page_size
=
get_uniform_page_size
(
kv_cache_spec
)
assert
len
(
page_sizes
)
==
1
num_blocks
=
get_num_blocks
(
vllm_config
,
len
(
kv_cache_spec
),
page_size
=
page_sizes
.
pop
()
available_memory
,
page_size
)
num_blocks
=
int
(
available_memory
//
page_size
//
len
(
kv_cache_spec
))
num_blocks
=
max
(
num_blocks
,
0
)
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
num_gpu_blocks_override
=
\
vllm_config
.
cache_config
.
num_gpu_blocks_override
logger
.
info
(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d"
,
num_blocks
,
num_gpu_blocks_override
)
num_blocks
=
num_gpu_blocks_override
per_layer_size
=
page_size
*
num_blocks
per_layer_size
=
page_size
*
num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
# for all layers.
grouped_layer_names
=
[
list
(
kv_cache_spec
.
keys
())]
grouped_layer_names
=
[
list
(
kv_cache_spec
.
keys
())]
# Each layer uses a separate Tensor to store its KV cache.
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
per_layer_size
,
shared_by
=
[
layer_name
])
for
layer_name
in
kv_cache_spec
]
kv_cache_config
=
KVCacheConfig
(
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
num_blocks
=
num_blocks
,
tensors
=
{
kv_cache_tensors
=
kv_cache_tensors
,
layer_name
:
KVCacheTensor
(
size
=
per_layer_size
)
for
layer_name
in
kv_cache_spec
},
kv_cache_groups
=
create_kv_cache_group_specs
(
kv_cache_spec
,
kv_cache_groups
=
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layer_names
),
grouped_layer_names
),
)
)
...
@@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
...
@@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return
kv_cache_config
return
kv_cache_config
def
is_kv_cache_page_size_uniform
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
"""
Whether all layers in the given KVCacheSpec have the same page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
True if all layers have the same page size, False otherwise.
"""
page_sizes
=
{
layer
.
page_size_bytes
for
layer
in
kv_cache_spec
.
values
()}
return
len
(
page_sizes
)
==
1
def
_get_kv_cache_config_uniform_page_size
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
"""
Generates the KV cache configuration for hybrid models with multiple
attention types but still with a uniform page size (physical memory per
block per layer) for all layers.
Detailed explanation about kv cache management of hybrid models:
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
of which contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers. It is already handled by
`_get_kv_cache_config_uniform_type`.
2. A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 kv_cache_groups, each of which represents 10 layers.
To simplify the implementation, we make the following assumptions:
1. Physical memory per block: Must be the same across all KV cache groups.
Breaking this assumption is non-trivial due to memory fragmentation concerns
when allocating blocks of different sizes.
2. Tokens per block (block_size): Currently, we directly use
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
cache group, but within each KV cache group, all layers must share the same
block size.
3. Physical memory per token per layer: This property is decided by model
config. Currently we only support models that have the same physical memory
per token per layer for all layers. Can be relaxed with a simple extension,
but still need to keep physical memory per block the same for all groups.
4. Number of layers per group: Currently assumed the same for all layers.
Can be relaxed with a simple extension, but still need to keep physical
memory per block the same for all groups.
5. Attention type within groups: All layers in a group must share the same
attention type. One exception is that, when
`--disable-hybrid-kv-cache-manager` is true, the single group for full
attention layers may also include attention layers using sliding window or
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
6. Support for multiple attention types: The design for most components is
general to an arbitrary number of attention types. But
`find_longest_cache_hit` only supports one attention type or two
types of full-attention plus exactly one another type. The general
implementation of this function is feasible but we don't know how to
implement it cleanly yet.
As we assume tokens per block, physical memory per token per layer, and
number of layers per group are the same now, we can ensure that physical
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
# Group all layers by type_id.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers
:
dict
[
str
,
list
[
str
]]
=
defaultdict
(
list
)
for
layer_name
,
layer_spec
in
kv_cache_spec
.
items
():
same_type_layers
[
layer_spec
.
type_id
].
append
(
layer_name
)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
# split to 3 groups with 2 layers each:
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
# open-source hybrid model follows a n:1 pattern between different attention
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
# full), so we can use the "1" in the n:1 pattern as the group size, which
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
group_size
=
min
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
grouped_layers
=
[]
for
layers
in
same_type_layers
.
values
():
num_padding_layers
=
group_size
-
len
(
layers
)
%
group_size
if
num_padding_layers
!=
group_size
:
logger
.
warning
(
"Add %d padding layers, may waste at most %.2f%% KV cache memory"
,
# noqa
num_padding_layers
,
num_padding_layers
/
len
(
layers
)
*
100
,
)
for
i
in
range
(
0
,
len
(
layers
),
group_size
):
grouped_layers
.
append
(
layers
[
i
:
i
+
group_size
])
kv_cache_groups
=
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layers
)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each is shared by one layer from
# each group. As layers of different groups have different block table,
# they will use different parts of the shared Tensor.
# The memory layout in the example will be:
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size
=
get_uniform_page_size
(
kv_cache_spec
)
num_blocks
=
get_num_blocks
(
vllm_config
,
group_size
,
available_memory
,
page_size
)
per_memory_pool_size
=
page_size
*
num_blocks
kv_cache_tensors
=
[]
for
i
in
range
(
group_size
):
shared_by
=
[]
for
j
in
range
(
len
(
kv_cache_groups
)):
if
i
<
len
(
grouped_layers
[
j
]):
shared_by
.
append
(
grouped_layers
[
j
][
i
])
kv_cache_tensors
.
append
(
KVCacheTensor
(
size
=
per_memory_pool_size
,
shared_by
=
shared_by
))
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
kv_cache_tensors
=
kv_cache_tensors
,
kv_cache_groups
=
kv_cache_groups
,
)
# Print the KV cache size and maximum concurrency.
num_tokens
=
num_blocks
//
len
(
grouped_layers
)
*
vllm_config
.
cache_config
.
block_size
num_tokens_str
=
f
"
{
num_tokens
:,
}
"
logger
.
info
(
"GPU KV cache size: %s tokens"
,
num_tokens_str
)
max_model_len_str
=
f
"
{
vllm_config
.
model_config
.
max_model_len
:,
}
"
max_concurrency
=
get_max_concurrency_for_kv_cache_config
(
vllm_config
,
kv_cache_config
)
logger
.
info
(
"Maximum concurrency for %s tokens per request: %.2fx"
,
max_model_len_str
,
max_concurrency
)
return
kv_cache_config
def
unify_hybrid_kv_cache_specs
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]):
def
unify_hybrid_kv_cache_specs
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]):
"""
"""
Only models with one type of KV cache are supported yet. This function tries
This function tries to convert the KV cache specs to one type if the model
to convert the KV cache specs to one type if the model is a hybrid model
is a hybrid model with multiple type of KV cache. It will convert all
with multiple type of KV cache. It will convert all SlidingWindowSpec to
SlidingWindowSpec to FullAttentionSpec if both types are present.
FullAttentionSpec if both types are present.
Args:
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
"""
def
is_hybrid
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
type_ids
=
set
(
layer_spec
.
type_id
for
layer_spec
in
kv_cache_spec
.
values
())
return
len
(
type_ids
)
>
1
if
not
is_hybrid
(
kv_cache_spec
):
return
logger
.
warning
(
"Hybrid KV cache manager is disabled for this hybrid model, "
"This means we do not enable any optimizations for saving KV cache "
"memory (e.g., dropping the KV cache outside the sliding window). "
"The compute of layers like sliding window is still saved."
)
has_full_attention
=
any
(
has_full_attention
=
any
(
isinstance
(
spec
,
FullAttentionSpec
)
for
spec
in
kv_cache_spec
.
values
())
isinstance
(
spec
,
FullAttentionSpec
)
for
spec
in
kv_cache_spec
.
values
())
has_sliding_window
=
any
(
has_sliding_window
=
any
(
...
@@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
...
@@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
sliding_window
=
spec
.
sliding_window
,
sliding_window
=
spec
.
sliding_window
,
)
)
if
is_hybrid
(
kv_cache_spec
):
raise
ValueError
(
"Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type."
)
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
def
get_kv_cache_config
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
vllm_config
:
VllmConfig
,
available_memory
:
int
)
->
KVCacheConfig
:
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
,
)
->
KVCacheConfig
:
"""
"""
Generates the KV cache configuration for a model
Generates the KV cache configuration for a model.
TODO: support hybrid models with more than one type of KV cache.
Args:
Args:
vllm_config: The global VllmConfig
vllm_config: The global VllmConfig
...
@@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
...
@@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
Returns:
Returns:
The generated KVCacheConfigs
The generated KVCacheConfigs
"""
"""
unify_hybrid_kv_cache_specs
(
kv_cache_spec
)
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
if
vllm_config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
:
unify_hybrid_kv_cache_specs
(
kv_cache_spec
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# most models. Allocate the same amount of memory for
# each layer.
# each layer.
return
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
return
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
)
available_memory
)
elif
is_kv_cache_page_size_uniform
(
kv_cache_spec
):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return
_get_kv_cache_config_uniform_page_size
(
vllm_config
,
kv_cache_spec
,
available_memory
)
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/v1/core/sched/scheduler.py
View file @
f8a1a2d1
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
SchedulerOutput
)
...
@@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
...
@@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
# after async KV recvs are completed.
else
:
else
:
new_computed_blocks
=
KVCacheBlocks
.
create_empty
()
new_computed_blocks
=
(
self
.
kv_cache_manager
.
create_empty_block_list
())
num_new_local_computed_tokens
=
0
num_new_local_computed_tokens
=
0
num_computed_tokens
=
request
.
num_computed_tokens
num_computed_tokens
=
request
.
num_computed_tokens
...
@@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens
=
len
(
block_ids
)
*
self
.
block_size
num_computed_tokens
=
len
(
block_ids
)
*
self
.
block_size
if
num_computed_tokens
==
request
.
num_tokens
:
if
num_computed_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
num_computed_tokens
-=
1
self
.
kv_cache_manager
.
single_type_manager
.
cache_blocks
(
self
.
kv_cache_manager
.
cache_blocks
(
request
,
request
,
self
.
kv_cache_manager
.
req_to_block_hashes
[
request
.
request_id
],
self
.
kv_cache_manager
.
req_to_block_hashes
[
request
.
request_id
],
num_computed_tokens
,
num_computed_tokens
,
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
f8a1a2d1
...
@@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
self
,
self
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
block_pool
:
BlockPool
,
block_pool
:
BlockPool
,
use_eagle
:
bool
,
kv_cache_group_id
:
int
,
num_kv_cache_groups
:
int
,
caching_hash_fn
:
Callable
,
caching_hash_fn
:
Callable
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
Args:
Args:
kv_cache_spec: The kv_cache_spec for this manager.
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
block_pool: The block pool.
use_eagle: Whether to use eagle.
kv_cache_group_id: The id of the kv cache group of this manager.
num_kv_cache_groups: The number of kv cache groups managed by this
manager.
caching_hash_fn: The caching hash function.
caching_hash_fn: The caching hash function.
"""
"""
...
@@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
self
.
block_pool
=
block_pool
# Needs special handling for find_longest_cache_hit if eagle is enabled
self
.
use_eagle
=
use_eagle
# Mapping from request ID to blocks to track the blocks allocated
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# for each request, so that we can free the blocks when the request
# is finished.
# is finished.
...
@@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
# data for reempted ones.
# data for reempted ones.
self
.
num_cached_block
:
dict
[
str
,
int
]
=
{}
self
.
num_cached_block
:
dict
[
str
,
int
]
=
{}
self
.
num_kv_cache_groups
=
num_kv_cache_groups
self
.
caching_hash_fn
=
caching_hash_fn
self
.
caching_hash_fn
=
caching_hash_fn
self
.
kv_cache_group_id
=
kv_cache_group_id
def
get_num_blocks_to_allocate
(
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
self
,
request_id
:
str
,
num_tokens
:
int
,
...
@@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
num_evictable_computed_blocks
=
sum
(
num_evictable_computed_blocks
=
sum
(
blk
.
ref_cnt
==
0
and
not
blk
.
is_null
blk
.
ref_cnt
==
0
and
not
blk
.
is_null
for
blk
in
new_computed_blocks
)
for
blk
in
new_computed_blocks
)
return
((
num_new_blocks
+
num_evictable_computed_blocks
)
*
return
num_new_blocks
+
num_evictable_computed_blocks
self
.
num_kv_cache_groups
)
def
save_new_computed_blocks
(
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
self
,
request_id
:
str
,
...
@@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
if
num_new_blocks
<=
0
:
if
num_new_blocks
<=
0
:
return
[]
return
[]
else
:
else
:
new_blocks
=
self
.
block_pool
.
get_new_blocks
(
new_blocks
=
self
.
block_pool
.
get_new_blocks
(
num_new_blocks
)
num_new_blocks
*
self
.
num_kv_cache_groups
)
req_blocks
.
extend
(
new_blocks
)
req_blocks
.
extend
(
new_blocks
)
return
new_blocks
return
new_blocks
...
@@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
num_cached_blocks
=
num_cached_blocks
,
num_cached_blocks
=
num_cached_blocks
,
num_full_blocks
=
num_full_blocks
,
num_full_blocks
=
num_full_blocks
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
kv_cache_group_id
=
self
.
kv_cache_group_id
,
hash_fn
=
self
.
caching_hash_fn
,
hash_fn
=
self
.
caching_hash_fn
,
)
)
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks
def
free
(
self
,
request_id
:
str
)
->
None
:
def
free
(
self
,
request_id
:
str
)
->
None
:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks
=
self
.
req_to_blocks
.
pop
(
request_id
,
[])
req_blocks
=
self
.
req_to_blocks
.
pop
(
request_id
,
[])
...
@@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
@
abstractmethod
@
abstractmethod
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
def
find_longest_cache_hit
(
max_length
:
int
)
->
list
[
KVCacheBlock
]:
cls
,
block_hashes
:
list
[
BlockHash
],
max_length
:
int
,
kv_cache_group_ids
:
list
[
int
],
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
"""
"""
Get the longest cache hit prefix of the blocks that is not longer than
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. If no cache hit is found, return an empty list.
`max_length`. The prefix should be a common prefix hit for all the
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
return an empty list.
If eagle is enabled, drop the last matched block to force recompute the
If eagle is enabled, drop the last matched block to force recompute the
last block to get the required hidden states for eagle drafting head.
last block to get the required hidden states for eagle drafting head.
Need to be customized for each attention type.
Need to be customized for each attention type.
...
@@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
Args:
Args:
block_hashes: The block hashes of the request.
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
Returns:
A list of cached blocks with skipped blocks replaced by null block.
A list of cached blocks with skipped blocks replaced by null block
for each kv cache group in `kv_cache_group_ids`.
Return a list of length `len(kv_cache_group_ids)`, where the i-th
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4
and
[
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]
]
for block size 4
sliding window 8
.
and
sliding window 8
and len(kv_cache_group_ids) = 1.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
num_computed_tokens
:
int
)
->
None
:
num_computed_tokens
:
int
)
->
None
:
"""
"""
Remove the blocks that are no longer needed from `blocks`. The removed
Remove the blocks that are no longer needed from `blocks` and free the
blocks should be replaced by null_block. Return the removed blocks in
blocks. The removed blocks should be replaced by null_block.
eviction order, where the first returned block should be evicted first.
Need to be customized for each attention type.
Don't free the removed blocks in this function. Need to be customized
for each attention type.
Args:
Args:
request_id: The request ID.
request_id: The request ID.
...
@@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
class
FullAttentionManager
(
SingleTypeKVCacheManager
):
class
FullAttentionManager
(
SingleTypeKVCacheManager
):
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
@
classmethod
max_length
:
int
)
->
list
[
KVCacheBlock
]:
def
find_longest_cache_hit
(
computed_blocks
:
list
[
KVCacheBlock
]
=
[]
cls
,
max_num_blocks
=
max_length
//
self
.
block_size
block_hashes
:
list
[
BlockHash
],
max_length
:
int
,
kv_cache_group_ids
:
list
[
int
],
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
"FullAttentionManager can only be used for full attention groups"
)
computed_blocks
:
list
[
list
[
KVCacheBlock
]]
=
[
[]
for
_
in
range
(
len
(
kv_cache_group_ids
))
]
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
for
i
in
range
(
max_num_blocks
):
for
i
in
range
(
max_num_blocks
):
block_hash
=
block_hashes
[
i
]
block_hash
=
block_hashes
[
i
]
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
# not computed yet for sure.
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
block_hash
):
if
cached_block
:
=
block_pool
.
get_cached_block
(
computed_blocks
.
append
(
cached_block
)
block_hash
,
kv_cache_group_ids
):
for
j
in
range
(
len
(
kv_cache_group_ids
)):
computed_blocks
[
j
].
append
(
cached_block
[
j
])
else
:
else
:
break
break
if
self
.
use_eagle
and
len
(
computed_blocks
)
>
0
:
if
use_eagle
and
len
(
computed_blocks
[
0
])
>
0
:
computed_blocks
.
pop
()
for
j
in
range
(
len
(
kv_cache_group_ids
)):
computed_blocks
[
j
].
pop
()
return
computed_blocks
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
@@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
...
@@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
class
SlidingWindowManager
(
SingleTypeKVCacheManager
):
class
SlidingWindowManager
(
SingleTypeKVCacheManager
):
def
__init__
(
self
,
kv_cache_spec
:
SlidingWindowSpec
,
block_pool
:
BlockPool
,
def
__init__
(
self
,
kv_cache_spec
:
SlidingWindowSpec
,
block_pool
:
BlockPool
,
use_eagle
:
bool
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
(
kv_cache_spec
,
block_pool
,
use_eagle
,
**
kwargs
)
super
().
__init__
(
kv_cache_spec
,
block_pool
,
**
kwargs
)
self
.
sliding_window
=
kv_cache_spec
.
sliding_window
self
.
sliding_window
=
kv_cache_spec
.
sliding_window
self
.
_null_block
=
block_pool
.
null_block
@
classmethod
def
find_longest_cache_hit
(
cls
,
block_hashes
:
list
[
BlockHash
],
max_length
:
int
,
kv_cache_group_ids
:
list
[
int
],
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
list
[
list
[
KVCacheBlock
]]:
assert
isinstance
(
kv_cache_spec
,
SlidingWindowSpec
),
(
"SlidingWindowManager can only be used for sliding window groups"
)
# The number of contiguous blocks needed for prefix cache hit.
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
# -1 since the input token itself is also included in the window
self
.
sliding_window_contiguous_blocks
=
cdiv
(
sliding_window_contiguous_blocks
=
cdiv
(
(
kv_cache_spec
.
sliding_window
-
1
)
,
self
.
block_size
)
kv_cache_spec
.
sliding_window
-
1
,
kv_cache_spec
.
block_size
)
if
self
.
use_eagle
:
if
use_eagle
:
# Need to drop the last matched block if eagle is enabled. For
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
# the last matched block.
self
.
sliding_window_contiguous_blocks
+=
1
sliding_window_contiguous_blocks
+=
1
self
.
_null_block
=
block_pool
.
null_block
def
find_longest_cache_hit
(
self
,
block_hashes
:
list
[
BlockHash
],
max_length
:
int
)
->
list
[
KVCacheBlock
]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
# which is good for low cache hit rate scenarios.
max_num_blocks
=
max_length
//
self
.
block_size
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
computed_blocks
=
[
self
.
_null_block
]
*
max_num_blocks
computed_blocks
=
[[
block_pool
.
null_block
]
*
max_num_blocks
for
_
in
range
(
len
(
kv_cache_group_ids
))]
num_contiguous_blocks
=
0
num_contiguous_blocks
=
0
match_found
=
False
match_found
=
False
# Search from right to left and early stop when a match is found.
# Search from right to left and early stop when a match is found.
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
for
i
in
range
(
max_num_blocks
-
1
,
-
1
,
-
1
):
if
cached_block
:
=
self
.
block_pool
.
get_cached_block
(
if
cached_block
:
=
block_pool
.
get_cached_block
(
block_hashes
[
i
]):
block_hashes
[
i
],
kv_cache_group_ids
):
computed_blocks
[
i
]
=
cached_block
for
j
in
range
(
len
(
kv_cache_group_ids
)):
computed_blocks
[
j
][
i
]
=
cached_block
[
j
]
num_contiguous_blocks
+=
1
num_contiguous_blocks
+=
1
if
(
num_contiguous_blocks
if
(
num_contiguous_blocks
>=
sliding_window_contiguous_blocks
):
>=
self
.
sliding_window_contiguous_blocks
):
# Trim the trailing blocks.
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
# when sliding_window_contiguous_blocks=2.
del
computed_blocks
[
i
+
num_contiguous_blocks
:]
for
j
in
range
(
len
(
kv_cache_group_ids
)):
del
computed_blocks
[
j
][
i
+
num_contiguous_blocks
:]
match_found
=
True
match_found
=
True
break
break
else
:
else
:
...
@@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
...
@@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if
not
match_found
:
if
not
match_found
:
# The first `num_contiguous_blocks` is a cache hit even if
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del
computed_blocks
[
num_contiguous_blocks
:]
for
j
in
range
(
len
(
kv_cache_group_ids
)):
if
self
.
use_eagle
and
len
(
computed_blocks
)
>
0
:
del
computed_blocks
[
j
][
num_contiguous_blocks
:]
computed_blocks
.
pop
()
if
use_eagle
and
len
(
computed_blocks
[
0
])
>
0
:
for
j
in
range
(
len
(
kv_cache_group_ids
)):
computed_blocks
[
j
].
pop
()
return
computed_blocks
return
computed_blocks
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
...
...
vllm/v1/kv_cache_interface.py
View file @
f8a1a2d1
...
@@ -157,11 +157,10 @@ class SlidingWindowSpec(AttentionSpec):
...
@@ -157,11 +157,10 @@ class SlidingWindowSpec(AttentionSpec):
@
dataclass
@
dataclass
class
KVCacheTensor
:
class
KVCacheTensor
:
"""
"""
A dataclass for specifying how the workers should initialize the KV cache
A class for specifying how the workers should initialize the KV cache.
for a layer. Only contains the size of KV cache for that layer for now. Will
be extended to support multiple layers sharing the same memory pool.
"""
"""
size
:
int
# The size of KV cache Tensor in bytes
size
:
int
# size of the KV cache tensor in bytes
shared_by
:
list
[
str
]
# layer names that share the same KV cache tensor
@
dataclass
@
dataclass
...
@@ -183,27 +182,13 @@ class KVCacheConfig:
...
@@ -183,27 +182,13 @@ class KVCacheConfig:
"""
"""
"""The number of KV cache blocks"""
"""The number of KV cache blocks"""
num_blocks
:
int
num_blocks
:
int
"""
layer_name -> how to
initialize KV cache
for that
layer"""
"""
How should model runner
initialize
the
KV cache
tensors for each
layer"""
tensors
:
dict
[
str
,
KVCacheTensor
]
kv_cache_tensors
:
list
[
KVCacheTensor
]
"""
"""
The kv cache groups of the model.
The kv cache groups of the model.
The layers in the models are repeated with some patterns, e.g., a model
For models with only one type of attention, there is only one group that
with 10 full attention layers and 20 sliding window attention layers can be
contains all layers.
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
For models with multiple types of attention, there will be multiple groups,
The KVCacheManager allocates different block tables for each of the 3 layers
see `_get_kv_cache_config_uniform_page_size` for more details.
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 groups, each of which
contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
"""
"""
kv_cache_groups
:
list
[
KVCacheGroupSpec
]
kv_cache_groups
:
list
[
KVCacheGroupSpec
]
vllm/v1/worker/gpu_model_runner.py
View file @
f8a1a2d1
...
@@ -2088,33 +2088,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2088,33 +2088,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes
=
block_sizes
,
block_sizes
=
block_sizes
,
)
)
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
_allocate_kv_cache_tensors
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
dict
[
str
,
torch
.
Tensor
]:
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
Args:
Args:
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: The KV cache config
cache size of each layer
Returns:
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
kv_cache_tensor
in
kv_cache_config
.
kv_cache_tensors
:
tensor
=
torch
.
zeros
(
kv_cache_tensor
.
size
,
dtype
=
torch
.
int8
,
device
=
self
.
device
)
for
layer_name
in
kv_cache_tensor
.
shared_by
:
kv_cache_raw_tensors
[
layer_name
]
=
tensor
layer_names
=
set
()
for
group
in
kv_cache_config
.
kv_cache_groups
:
layer_names
.
update
(
group
.
layer_names
)
assert
layer_names
==
set
(
kv_cache_raw_tensors
.
keys
(
)),
"Some layers are not correctly initialized"
return
kv_cache_raw_tensors
def
_reshape_kv_cache_tensors
(
self
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
)
->
dict
[
str
,
torch
.
Tensor
]:
"""
"""
self
.
kv_cache_config
=
kv_cache_config
Reshape the KV cache tensors to the desired shape and dtype.
self
.
may_reinitialize_input_batch
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
i
,
kv_cache_group_spec
in
enumerate
(
for
i
,
kv_cache_group
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
kv_cache_config
.
kv_cache_groups
):
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
for
layer_name
in
kv_cache_group
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
raw_tensor
=
kv_cache_raw_tensors
[
layer_name
]
assert
tensor_config
.
size
%
kv_cache_spec
.
page_size_bytes
==
0
assert
raw_tensor
.
numel
()
%
kv_cache_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
kv_cache_spec
.
page_size_bytes
num_blocks
=
(
raw_tensor
.
numel
()
//
# `num_blocks` is the number of blocks the model runner can use.
kv_cache_spec
.
page_size_bytes
)
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different number of layers and
# different memory capacities, `num_blocks` can be different on
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert
num_blocks
>=
kv_cache_config
.
num_blocks
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
kv_cache_shape
=
self
.
attn_backends
[
i
].
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
num_blocks
,
kv_cache_spec
.
block_size
,
...
@@ -2140,13 +2165,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2140,13 +2165,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_stride_order
.
index
(
i
)
kv_cache_stride_order
.
index
(
i
)
for
i
in
range
(
len
(
kv_cache_stride_order
))
for
i
in
range
(
len
(
kv_cache_stride_order
))
]
]
kv_caches
[
layer_name
]
=
torch
.
zeros
(
kv_caches
[
layer_name
]
=
kv_cache_raw_tensors
[
kv_cache_shape
,
dtype
=
dtype
,
layer_name
].
view
(
dtype
).
view
(
kv_cache_shape
).
permute
(
device
=
self
.
device
).
permute
(
*
inv_order
)
*
inv_order
)
else
:
else
:
# TODO: add new branches when introducing more types of
raise
NotImplementedError
# KV cache specs.
return
kv_caches
raise
ValueError
(
"Unknown KV cache spec type."
)
def
initialize_kv_cache_tensors
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
dict
[
str
,
torch
.
Tensor
]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors
=
self
.
_allocate_kv_cache_tensors
(
kv_cache_config
)
# Change the memory buffer to the desired shape
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
kv_cache_config
,
kv_cache_raw_tensors
)
# Setup `kv_cache_config` and `kv_caches` for models
# Setup `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
# with cross-layer KV sharing
...
@@ -2157,17 +2198,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2157,17 +2198,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches
,
kv_caches
,
)
)
bind_kv_cache
(
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
return
kv_caches
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
self
.
kv_cache_config
=
kv_cache_config
self
.
may_reinitialize_input_batch
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
kv_caches
=
self
.
initialize_kv_cache_tensors
(
kv_cache_config
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# validate all draft model layers belong to the same kv cache
# validate all draft model layers belong to the same kv cache
# group
# group
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
self
.
drafter
.
validate_same_kv_cache_group
(
kv_cache_config
)
bind_kv_cache
(
kv_caches
,
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment