Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df334868
Unverified
Commit
df334868
authored
Oct 31, 2025
by
Chen Zhang
Committed by
GitHub
Oct 31, 2025
Browse files
[Hybrid] A simpler algorithm to find kernel_block_size (#26476)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
0e0a638c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
149 additions
and
85 deletions
+149
-85
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+53
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+91
-84
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+5
-1
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
df334868
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch
import
torch
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.attention.backends.abstract
import
MultipleOf
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
CacheConfig
,
ModelConfig
,
ModelConfig
,
...
@@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
...
@@ -34,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.utils
import
AttentionGroup
BLOCK_SIZE
=
16
BLOCK_SIZE
=
16
NUM_BLOCKS
=
10
NUM_BLOCKS
=
10
...
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
...
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
).
all
()
).
all
()
def _make_mock_backend_for_kernel_block_size(
    supported_sizes: list[int | MultipleOf],
):
    """Build a stub attention backend that advertises ``supported_sizes``.

    Args:
        supported_sizes: The value the stub hands back from
            ``get_supported_kernel_block_size``. Each entry is either an
            ``int`` or a ``MultipleOf``.

    Returns:
        An instance of a locally defined ``_MockBackend`` class whose only
        member is a static ``get_supported_kernel_block_size`` method. The
        method returns the very same ``supported_sizes`` object (no copy).
    """

    class _MockBackend:
        # Static so it can be called on the instance or the class without
        # binding ``self``; the list is captured from the enclosing call.
        @staticmethod
        def get_supported_kernel_block_size():
            return supported_sizes

    stub_backend = _MockBackend()
    return stub_backend
def _make_kv_cache_spec() -> FullAttentionSpec:
    """Return a minimal ``FullAttentionSpec`` to use as a placeholder.

    All size fields are set to 1; the tests in this file only need a spec
    object to populate an ``AttentionGroup``.
    """
    # NOTE(review): dtype is passed as the string "float16" rather than a
    # torch dtype object -- confirm FullAttentionSpec tolerates that.
    spec_fields = {
        "block_size": 1,
        "num_kv_heads": 1,
        "head_size": 1,
        "dtype": "float16",
    }
    return FullAttentionSpec(**spec_fields)
def test_select_common_block_size_prefers_manager_block_size():
    """The KV-manager block size itself is selected when all backends take it.

    One backend advertises ``MultipleOf(32)`` and the other advertises ``64``
    and ``MultipleOf(16)``. With a manager block size of 128 the selected
    kernel block size is expected to be 128.
    """
    supported_per_backend = [
        [MultipleOf(32)],
        [64, MultipleOf(16)],
    ]
    # One attention group per mock backend; builders and layer names are
    # left empty and every group uses kv-cache group id 0.
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in supported_per_backend
    ]

    selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)

    assert selected_size == 128
def test_select_common_block_size_uses_largest_shared_int():
    """The largest integer size shared by every backend is selected.

    The backends advertise ``[128, 64]`` and ``[64, 32]``. With a manager
    block size of 256 the selected kernel block size is expected to be 64.
    """
    supported_per_backend = [
        [128, 64],
        [64, 32],
    ]
    # One attention group per mock backend; builders and layer names are
    # left empty and every group uses kv-cache group id 0.
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in supported_per_backend
    ]

    selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)

    assert selected_size == 64
def test_select_common_block_size_no_valid_option():
    """A ``ValueError`` is raised when no size satisfies every backend.

    The backends advertise ``[64]`` and ``[MultipleOf(16)]``; the manager
    block size under test is 48.
    """
    supported_per_backend = [
        [64],
        [MultipleOf(16)],
    ]
    # One attention group per mock backend; builders and layer names are
    # left empty and every group uses kv-cache group id 0.
    attn_groups = [
        AttentionGroup(
            _make_mock_backend_for_kernel_block_size(sizes),
            [],
            [],
            _make_kv_cache_spec(),
            0,
        )
        for sizes in supported_per_backend
    ]

    with pytest.raises(ValueError):
        GPUModelRunner.select_common_block_size(48, attn_groups)
def
test_update_states_new_request
(
model_runner
,
dist_init
):
def
test_update_states_new_request
(
model_runner
,
dist_init
):
req_id
=
"req_0"
req_id
=
"req_0"
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
df334868
...
@@ -3978,6 +3978,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3978,6 +3978,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def
create_attn_groups
(
def
create_attn_groups
(
attn_backends_map
:
dict
[
AttentionGroupKey
,
list
[
str
]],
attn_backends_map
:
dict
[
AttentionGroupKey
,
list
[
str
]],
kv_cache_group_id
:
int
,
)
->
list
[
AttentionGroup
]:
)
->
list
[
AttentionGroup
]:
attn_groups
:
list
[
AttentionGroup
]
=
[]
attn_groups
:
list
[
AttentionGroup
]
=
[]
for
(
attn_backend
,
kv_cache_spec
),
layer_names
in
attn_backends_map
.
items
():
for
(
attn_backend
,
kv_cache_spec
),
layer_names
in
attn_backends_map
.
items
():
...
@@ -3987,6 +3988,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3987,6 +3988,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec
,
kv_cache_spec
,
self
.
vllm_config
,
self
.
vllm_config
,
self
.
device
,
self
.
device
,
kv_cache_group_id
,
num_metadata_builders
=
1
num_metadata_builders
=
1
if
not
self
.
parallel_config
.
enable_dbo
if
not
self
.
parallel_config
.
enable_dbo
else
2
,
else
2
,
...
@@ -4005,8 +4007,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4005,8 +4007,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Resolve cudagraph_mode before actually initializing metadata_builders
# Resolve cudagraph_mode before actually initializing metadata_builders
self
.
_check_and_update_cudagraph_mode
(
attention_backend_set
)
self
.
_check_and_update_cudagraph_mode
(
attention_backend_set
)
for
attn_backend
s
_map
in
attention_backend_maps
:
for
i
,
attn_backend_map
in
enumerate
(
attention_backend_maps
)
:
self
.
attn_groups
.
append
(
create_attn_groups
(
attn_backend
s
_map
))
self
.
attn_groups
.
append
(
create_attn_groups
(
attn_backend_map
,
i
))
# Calculate reorder batch threshold (if needed)
# Calculate reorder batch threshold (if needed)
self
.
calculate_reorder_batch_threshold
()
self
.
calculate_reorder_batch_threshold
()
...
@@ -4156,87 +4158,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4156,87 +4158,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
return
self
.
reorder_batch_threshold
=
reduce
(
min_none_high
,
reorder_batch_thresholds
)
self
.
reorder_batch_threshold
=
reduce
(
min_none_high
,
reorder_batch_thresholds
)
def
_find_compatible_block_sizes
(
@
staticmethod
self
,
def
select_common_block_size
(
kv_manager_block_size
:
int
,
kv_manager_block_size
:
int
,
attn_groups
:
list
[
AttentionGroup
]
backend_cls
:
type
[
AttentionBackend
],
return_all
:
bool
=
False
,
)
->
list
[
int
]:
"""
Find compatible block sizes for a backend.
Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False
Returns:
Compatible block size(s) based on return_all parameter
Raises:
ValueError: If no compatible block size found
"""
supported_block_size
=
backend_cls
.
get_supported_kernel_block_size
()
compatible_sizes
=
[]
for
block_size
in
supported_block_size
:
if
isinstance
(
block_size
,
int
):
if
kv_manager_block_size
%
block_size
==
0
:
compatible_sizes
.
append
(
block_size
)
elif
(
isinstance
(
block_size
,
MultipleOf
)
and
kv_manager_block_size
%
block_size
.
base
==
0
):
compatible_sizes
.
append
(
kv_manager_block_size
)
if
not
compatible_sizes
:
raise
ValueError
(
f
"No compatible block size for
{
kv_manager_block_size
}
"
)
return
compatible_sizes
if
return_all
else
[
max
(
compatible_sizes
)]
def
_select_common_block_size
(
self
,
kv_manager_block_size
:
int
,
attn_groups
:
list
[
AttentionGroup
]
)
->
int
:
)
->
int
:
"""
"""
Select common block size for all backends.
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.
If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the max supported size.
Args:
Args:
kv_manager_block_size: Block size of KV cache
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups
attn_groups: List of attention groups
Returns:
Returns:
Block size supported by all backends,
The selected block size
prioritizing cache_config.block_size
Raises:
Raises:
ValueError: If no
common
block size found
ValueError: If no
valid
block size found
"""
"""
all_backend_supports
=
[]
for
attn_group
in
attn_groups
:
def
block_size_is_supported
(
compatible_sizes
=
self
.
_find_compatible_block_sizes
(
backends
:
list
[
type
[
AttentionBackend
]],
block_size
:
int
kv_manager_block_size
,
attn_group
.
backend
,
return_all
=
True
)
->
bool
:
)
"""
supported_sizes
=
sorted
(
list
(
set
(
compatible_sizes
)),
reverse
=
True
)
Check if the block size is supported by all backends.
all_backend_supports
.
append
(
set
(
supported_sizes
))
"""
for
backend
in
backends
:
common_supported_sizes
=
set
.
intersection
(
*
all_backend_supports
)
is_supported
=
False
for
supported_size
in
backend
.
get_supported_kernel_block_size
():
if
not
common_supported_sizes
:
if
isinstance
(
supported_size
,
int
):
error_msg
=
f
"No common block size for
{
kv_manager_block_size
}
. "
if
block_size
==
supported_size
:
for
i
,
attn_group
in
enumerate
(
attn_groups
):
is_supported
=
True
supported
=
all_backend_supports
[
i
]
elif
isinstance
(
supported_size
,
MultipleOf
):
error_msg
+=
(
if
block_size
%
supported_size
.
base
==
0
:
f
"Backend
{
attn_group
.
backend
}
supports:
{
sorted
(
supported
)
}
. "
is_supported
=
True
)
else
:
raise
ValueError
(
error_msg
)
raise
ValueError
(
f
"Unknown supported size:
{
supported_size
}
"
)
if
not
is_supported
:
if
self
.
cache_config
.
block_size
in
common_supported_sizes
:
return
False
return
self
.
cache_config
.
block_size
return
True
backends
=
[
group
.
backend
for
group
in
attn_groups
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
if
block_size_is_supported
(
backends
,
kv_manager_block_size
):
return
kv_manager_block_size
# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b is a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
# return kv_manager_block_size in case 1.
all_int_supported_sizes
=
set
(
supported_size
for
backend
in
backends
for
supported_size
in
backend
.
get_supported_kernel_block_size
()
if
isinstance
(
supported_size
,
int
)
)
return
max
(
common_supported_sizes
)
for
supported_size
in
sorted
(
all_int_supported_sizes
,
reverse
=
True
):
if
kv_manager_block_size
%
supported_size
!=
0
:
continue
if
block_size_is_supported
(
backends
,
supported_size
):
return
supported_size
raise
ValueError
(
f
"No common block size for
{
kv_manager_block_size
}
. "
)
def
may_reinitialize_input_batch
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
may_reinitialize_input_batch
(
self
,
kv_cache_config
:
KVCacheConfig
,
kernel_block_sizes
:
list
[
int
]
)
->
None
:
"""
"""
Re-initialize the input batch if the block sizes are different from
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
`[self.cache_config.block_size]`. This usually happens when there
...
@@ -4244,6 +4240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4244,6 +4240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args:
Args:
kv_cache_config: The KV cache configuration.
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
"""
"""
block_sizes
=
[
block_sizes
=
[
kv_cache_group
.
kv_cache_spec
.
block_size
kv_cache_group
.
kv_cache_spec
.
block_size
...
@@ -4251,9 +4248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4251,9 +4248,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
isinstance
(
kv_cache_group
.
kv_cache_spec
,
EncoderOnlyAttentionSpec
)
if
not
isinstance
(
kv_cache_group
.
kv_cache_spec
,
EncoderOnlyAttentionSpec
)
]
]
# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes
=
self
.
_prepare_kernel_block_sizes
(
kv_cache_config
)
if
block_sizes
!=
[
self
.
cache_config
.
block_size
]
or
kernel_block_sizes
!=
[
if
block_sizes
!=
[
self
.
cache_config
.
block_size
]
or
kernel_block_sizes
!=
[
self
.
cache_config
.
block_size
self
.
cache_config
.
block_size
]:
]:
...
@@ -4354,7 +4348,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4354,7 +4348,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# all backends in the group.
# all backends in the group.
attn_groups
=
self
.
attn_groups
[
kv_cache_group_id
]
attn_groups
=
self
.
attn_groups
[
kv_cache_group_id
]
kv_manager_block_size
=
kv_cache_group
.
kv_cache_spec
.
block_size
kv_manager_block_size
=
kv_cache_group
.
kv_cache_spec
.
block_size
selected_kernel_size
=
self
.
_
select_common_block_size
(
selected_kernel_size
=
self
.
select_common_block_size
(
kv_manager_block_size
,
attn_groups
kv_manager_block_size
,
attn_groups
)
)
kernel_block_sizes
.
append
(
selected_kernel_size
)
kernel_block_sizes
.
append
(
selected_kernel_size
)
...
@@ -4372,6 +4366,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4372,6 +4366,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
self
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
kv_cache_raw_tensors
:
dict
[
str
,
torch
.
Tensor
],
kernel_block_sizes
:
list
[
int
],
)
->
dict
[
str
,
torch
.
Tensor
]:
)
->
dict
[
str
,
torch
.
Tensor
]:
"""
"""
Reshape the KV cache tensors to the desired shape and dtype.
Reshape the KV cache tensors to the desired shape and dtype.
...
@@ -4380,6 +4375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4380,6 +4375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
correct size but uninitialized shape.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
...
@@ -4389,6 +4385,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4389,6 +4385,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
group
in
self
.
_kv_cache_spec_attn_group_iterator
():
for
group
in
self
.
_kv_cache_spec_attn_group_iterator
():
kv_cache_spec
=
group
.
kv_cache_spec
kv_cache_spec
=
group
.
kv_cache_spec
attn_backend
=
group
.
backend
attn_backend
=
group
.
backend
if
group
.
kv_cache_group_id
==
len
(
kernel_block_sizes
):
# There may be a last group for layers without kv cache.
continue
kernel_block_size
=
kernel_block_sizes
[
group
.
kv_cache_group_id
]
for
layer_name
in
group
.
layer_names
:
for
layer_name
in
group
.
layer_names
:
if
layer_name
in
self
.
runner_only_attn_layers
:
if
layer_name
in
self
.
runner_only_attn_layers
:
continue
continue
...
@@ -4397,24 +4397,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4397,24 +4397,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_blocks
=
raw_tensor
.
numel
()
//
kv_cache_spec
.
page_size_bytes
num_blocks
=
raw_tensor
.
numel
()
//
kv_cache_spec
.
page_size_bytes
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
has_attn
=
True
has_attn
=
True
kv_manager_block_size
=
kv_cache_spec
.
block_size
num_blocks_per_kv_block
=
(
kernel_size_list
=
self
.
_find_compatible_block_sizes
(
kv_cache_spec
.
block_size
//
kernel_block_size
kv_manager_block_size
,
attn_backend
,
return_all
=
False
)
)
kernel_size
=
kernel_size_list
[
0
]
num_blocks_per_kv_block
=
kv_manager_block_size
//
kernel_size
kernel_num_blocks
=
num_blocks
*
num_blocks_per_kv_block
kernel_num_blocks
=
num_blocks
*
num_blocks_per_kv_block
kv_cache_shape
=
attn_backend
.
get_kv_cache_shape
(
kv_cache_shape
=
attn_backend
.
get_kv_cache_shape
(
kernel_num_blocks
,
kernel_num_blocks
,
kernel_size
,
kernel_
block_
size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
,
kv_cache_spec
.
head_size
,
cache_dtype_str
=
self
.
cache_config
.
cache_dtype
,
cache_dtype_str
=
self
.
cache_config
.
cache_dtype
,
)
)
dtype
=
kv_cache_spec
.
dtype
dtype
=
kv_cache_spec
.
dtype
try
:
try
:
kv_cache_stride_order
=
attn_backend
.
get_kv_cache_stride_order
()
# noqa: E501
kv_cache_stride_order
=
attn_backend
.
get_kv_cache_stride_order
()
assert
len
(
kv_cache_stride_order
)
==
len
(
kv_cache_shape
)
assert
len
(
kv_cache_stride_order
)
==
len
(
kv_cache_shape
)
except
(
AttributeError
,
NotImplementedError
):
except
(
AttributeError
,
NotImplementedError
):
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_shape
)))
kv_cache_stride_order
=
tuple
(
range
(
len
(
kv_cache_shape
)))
...
@@ -4497,13 +4494,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4497,13 +4494,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
def
initialize_kv_cache_tensors
(
def
initialize_kv_cache_tensors
(
self
,
kv_cache_config
:
KVCacheConfig
self
,
kv_cache_config
:
KVCacheConfig
,
kernel_block_sizes
:
list
[
int
]
)
->
dict
[
str
,
torch
.
Tensor
]:
)
->
dict
[
str
,
torch
.
Tensor
]:
"""
"""
Initialize the memory buffer for KV cache.
Initialize the memory buffer for KV cache.
Args:
Args:
kv_cache_config: The KV cache config
kv_cache_config: The KV cache config
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
corresponding memory buffer for KV cache.
...
@@ -4512,7 +4511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4512,7 +4511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_raw_tensors
=
self
.
_allocate_kv_cache_tensors
(
kv_cache_config
)
kv_cache_raw_tensors
=
self
.
_allocate_kv_cache_tensors
(
kv_cache_config
)
# Change the memory buffer to the desired shape
# Change the memory buffer to the desired shape
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
kv_caches
=
self
.
_reshape_kv_cache_tensors
(
kv_cache_config
,
kv_cache_raw_tensors
kv_cache_config
,
kv_cache_raw_tensors
,
kernel_block_sizes
)
)
# Set up cross-layer KV cache sharing
# Set up cross-layer KV cache sharing
...
@@ -4571,9 +4570,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -4571,9 +4570,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
may_add_encoder_only_layers_to_kv_cache_config
()
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
maybe_add_kv_sharing_layers_to_kv_cache_groups
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
self
.
initialize_attn_backend
(
kv_cache_config
)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only support block_size 64, we will return
# kernel_block_size 64 and split the 256-token block into 4 blocks with 64
# tokens each.
kernel_block_sizes
=
self
.
_prepare_kernel_block_sizes
(
kv_cache_config
)
# Reinitialization must happen after initialize_attn_backend
# Reinitialization must happen after initialize_attn_backend
self
.
may_reinitialize_input_batch
(
kv_cache_config
)
self
.
may_reinitialize_input_batch
(
kv_cache_config
,
kernel_block_sizes
)
kv_caches
=
self
.
initialize_kv_cache_tensors
(
kv_cache_config
)
kv_caches
=
self
.
initialize_kv_cache_tensors
(
kv_cache_config
,
kernel_block_sizes
)
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
...
...
vllm/v1/worker/utils.py
View file @
df334868
...
@@ -140,6 +140,7 @@ class AttentionGroup:
...
@@ -140,6 +140,7 @@ class AttentionGroup:
metadata_builders
:
list
[
AttentionMetadataBuilder
]
metadata_builders
:
list
[
AttentionMetadataBuilder
]
layer_names
:
list
[
str
]
layer_names
:
list
[
str
]
kv_cache_spec
:
KVCacheSpec
kv_cache_spec
:
KVCacheSpec
kv_cache_group_id
:
int
@
staticmethod
@
staticmethod
def
create_with_metadata_builders
(
def
create_with_metadata_builders
(
...
@@ -148,13 +149,16 @@ class AttentionGroup:
...
@@ -148,13 +149,16 @@ class AttentionGroup:
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
device
:
torch
.
device
,
kv_cache_group_id
:
int
,
num_metadata_builders
:
int
=
1
,
num_metadata_builders
:
int
=
1
,
)
->
"AttentionGroup"
:
)
->
"AttentionGroup"
:
metadata_builders
=
[
metadata_builders
=
[
backend
.
get_builder_cls
()(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
backend
.
get_builder_cls
()(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
for
_
in
range
(
num_metadata_builders
)
for
_
in
range
(
num_metadata_builders
)
]
]
return
AttentionGroup
(
backend
,
metadata_builders
,
layer_names
,
kv_cache_spec
)
return
AttentionGroup
(
backend
,
metadata_builders
,
layer_names
,
kv_cache_spec
,
kv_cache_group_id
)
def
get_metadata_builder
(
self
,
ubatch_id
:
int
=
0
)
->
AttentionMetadataBuilder
:
def
get_metadata_builder
(
self
,
ubatch_id
:
int
=
0
)
->
AttentionMetadataBuilder
:
assert
len
(
self
.
metadata_builders
)
>
ubatch_id
assert
len
(
self
.
metadata_builders
)
>
ubatch_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment