Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18961c5e
Unverified
Commit
18961c5e
authored
Nov 03, 2025
by
Thomas Parnell
Committed by
GitHub
Nov 03, 2025
Browse files
[Hybrid] Pass kernel block size to builders (#27753)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
470ad118
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
27 deletions
+62
-27
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-1
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+7
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+25
-6
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+25
-19
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
18961c5e
...
...
@@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
def
get_supported_kernel_block_size
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return
[
16
,
32
,
64
]
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
...
...
vllm/v1/kv_cache_interface.py
View file @
18961c5e
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
,
replace
from
math
import
prod
import
torch
...
...
@@ -44,6 +44,12 @@ class KVCacheSpec:
"""
raise
NotImplementedError
def
copy_with_new_block_size
(
self
,
block_size
:
int
)
->
Self
:
"""
Create a new KVCacheSpec from self but replacing the block size.
"""
return
replace
(
self
,
block_size
=
block_size
)
@
classmethod
def
merge
(
cls
,
specs
:
list
[
Self
])
->
Self
:
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
18961c5e
...
...
@@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
->
list
[
AttentionGroup
]:
attn_groups
:
list
[
AttentionGroup
]
=
[]
for
(
attn_backend
,
kv_cache_spec
),
layer_names
in
attn_backends_map
.
items
():
attn_group
=
AttentionGroup
.
create_with_metadata_builders
(
attn_group
=
AttentionGroup
(
attn_backend
,
layer_names
,
kv_cache_spec
,
self
.
vllm_config
,
self
.
device
,
kv_cache_group_id
,
num_metadata_builders
=
1
if
not
self
.
parallel_config
.
enable_dbo
else
2
,
)
attn_groups
.
append
(
attn_group
)
...
...
@@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
i
,
attn_backend_map
in
enumerate
(
attention_backend_maps
):
self
.
attn_groups
.
append
(
create_attn_groups
(
attn_backend_map
,
i
))
def
initialize_metadata_builders
(
self
,
kv_cache_config
:
KVCacheConfig
,
kernel_block_sizes
:
list
[
int
]
)
->
None
:
"""
Create the metadata builders for all KV cache groups and attn groups.
"""
for
kv_cache_group_id
in
range
(
len
(
kv_cache_config
.
kv_cache_groups
)):
for
attn_group
in
self
.
attn_groups
[
kv_cache_group_id
]:
attn_group
.
create_metadata_builders
(
self
.
vllm_config
,
self
.
device
,
kernel_block_sizes
[
kv_cache_group_id
]
if
kv_cache_group_id
<
len
(
kernel_block_sizes
)
else
None
,
num_metadata_builders
=
1
if
not
self
.
parallel_config
.
enable_dbo
else
2
,
)
# Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,
# because some of them change the threshold at init time.
self
.
calculate_reorder_batch_threshold
()
def
_check_and_update_cudagraph_mode
(
...
...
@@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes
=
self
.
_prepare_kernel_block_sizes
(
kv_cache_config
)
# create metadata builders
self
.
initialize_metadata_builders
(
kv_cache_config
,
kernel_block_sizes
)
# Reinitialize need to after initialize_attn_backend
self
.
may_reinitialize_input_batch
(
kv_cache_config
,
kernel_block_sizes
)
kv_caches
=
self
.
initialize_kv_cache_tensors
(
...
...
vllm/v1/worker/utils.py
View file @
18961c5e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
import
torch
...
...
@@ -134,31 +134,37 @@ class MultiModalBudget:
@
dataclass
class
AttentionGroup
:
backend
:
type
[
AttentionBackend
]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders
:
list
[
AttentionMetadataBuilder
]
layer_names
:
list
[
str
]
kv_cache_spec
:
KVCacheSpec
kv_cache_group_id
:
int
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders
:
list
[
AttentionMetadataBuilder
]
=
field
(
default_factory
=
lambda
:
[]
)
@
staticmethod
def
create_with_metadata_builders
(
backend
:
type
[
AttentionBackend
],
layer_names
:
list
[
str
],
kv_cache_spec
:
KVCacheSpec
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
kv_cache_group_id
:
int
,
def
create_metadata_builders
(
self
,
vllm_config
,
device
,
kernel_block_size
:
int
|
None
,
num_metadata_builders
:
int
=
1
,
)
->
"AttentionGroup"
:
metadata_builders
=
[
backend
.
get_builder_cls
()(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
):
kv_cache_spec_builder
=
(
self
.
kv_cache_spec
.
copy_with_new_block_size
(
kernel_block_size
)
if
kernel_block_size
is
not
None
else
self
.
kv_cache_spec
)
self
.
metadata_builders
=
[
self
.
backend
.
get_builder_cls
()(
kv_cache_spec_builder
,
self
.
layer_names
,
vllm_config
,
device
,
)
for
_
in
range
(
num_metadata_builders
)
]
return
AttentionGroup
(
backend
,
metadata_builders
,
layer_names
,
kv_cache_spec
,
kv_cache_group_id
)
def
get_metadata_builder
(
self
,
ubatch_id
:
int
=
0
)
->
AttentionMetadataBuilder
:
assert
len
(
self
.
metadata_builders
)
>
ubatch_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment