Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71df2a57
Unverified
Commit
71df2a57
authored
Nov 24, 2025
by
Chen Zhang
Committed by
GitHub
Nov 24, 2025
Browse files
[Hybrid Allocator] Better layer padding strategy for gpt-oss eagle (#29303)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
4dd42db5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
1 deletion
+69
-1
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+59
-0
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+10
-1
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
71df2a57
...
@@ -1436,6 +1436,65 @@ def test_get_kv_cache_config_one_worker():
...
@@ -1436,6 +1436,65 @@ def test_get_kv_cache_config_one_worker():
],
],
)
)
# 6 full + 5 sliding, pad to 6 full + 6 sliding. This is a typical case for gpt-oss
# eagle where there is only one more full attention layer than sliding window layers
kv_cache_specs_hybrid
=
{
"layer_1"
:
new_kv_cache_spec
(),
"layer_2"
:
new_kv_cache_spec
(),
"layer_3"
:
new_kv_cache_spec
(),
"layer_4"
:
new_kv_cache_spec
(),
"layer_5"
:
new_kv_cache_spec
(),
"layer_6"
:
new_kv_cache_spec
(),
"layer_7"
:
new_sliding_window_spec
(),
"layer_8"
:
new_sliding_window_spec
(),
"layer_9"
:
new_sliding_window_spec
(),
"layer_10"
:
new_sliding_window_spec
(),
"layer_11"
:
new_sliding_window_spec
(),
}
kv_cache_config_hybrid
=
get_kv_cache_configs
(
vllm_config
,
[
kv_cache_specs_hybrid
],
[
mem_per_block_per_layer
*
6
*
32
]
)[
0
]
print
(
kv_cache_config_hybrid
)
assert
kv_cache_config_hybrid
==
KVCacheConfig
(
num_blocks
=
32
,
kv_cache_tensors
=
[
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_1"
,
"layer_7"
],
),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_2"
,
"layer_8"
],
),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_3"
,
"layer_9"
],
),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_4"
,
"layer_10"
],
),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_5"
,
"layer_11"
],
),
KVCacheTensor
(
size
=
mem_per_block_per_layer
*
32
,
shared_by
=
[
"layer_6"
],
),
],
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer_1"
,
"layer_2"
,
"layer_3"
,
"layer_4"
,
"layer_5"
,
"layer_6"
],
new_kv_cache_spec
(),
),
KVCacheGroupSpec
(
[
"layer_7"
,
"layer_8"
,
"layer_9"
,
"layer_10"
,
"layer_11"
],
new_sliding_window_spec
(),
),
],
)
# different hidden size
# different hidden size
kv_cache_specs_hybrid
=
{
kv_cache_specs_hybrid
=
{
"layer_1"
:
new_kv_cache_spec
(
head_size
=
128
),
"layer_1"
:
new_kv_cache_spec
(
head_size
=
128
),
...
...
vllm/v1/core/kv_cache_utils.py
View file @
71df2a57
...
@@ -971,7 +971,16 @@ def _get_kv_cache_groups_uniform_page_size(
...
@@ -971,7 +971,16 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
# sw, where the group size should be 10).
group_size
=
min
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
min_num_layers
=
min
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
group_size
=
min_num_layers
max_num_layers
=
max
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
if
max_num_layers
<
min_num_layers
*
1.25
:
# If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding
# layers. A typical example is gpt-oss-20b + eagle, with 12 sw + 13 full. We
# pad it to (13 sw, 13 full) instead of (12 sw, 24 full). 1.25 is just a
# magic number to avoid too many padding layers.
group_size
=
max_num_layers
grouped_layers
=
[]
grouped_layers
=
[]
for
layers
in
same_type_layers
.
values
():
for
layers
in
same_type_layers
.
values
():
num_padding_layers
=
group_size
-
len
(
layers
)
%
group_size
num_padding_layers
=
group_size
-
len
(
layers
)
%
group_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment