Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77a73458
Unverified
Commit
77a73458
authored
Mar 09, 2026
by
Matthew Bonanni
Committed by
GitHub
Mar 09, 2026
Browse files
Reapply [Attention] Refactor `check_and_update_config` (#35122)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
5578f2a4
Changes
32
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
237 additions
and
256 deletions
+237
-256
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+40
-35
tests/models/multimodal/processing/test_tensor_schema.py
tests/models/multimodal/processing/test_tensor_schema.py
+4
-1
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+1
-1
vllm/config/cache.py
vllm/config/cache.py
+26
-8
vllm/config/vllm.py
vllm/config/vllm.py
+47
-46
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+0
-1
vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
...ed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
+0
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-3
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+0
-3
vllm/model_executor/layers/attention/chunked_local_attention.py
...odel_executor/layers/attention/chunked_local_attention.py
+6
-9
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+0
-3
vllm/model_executor/layers/attention/encoder_only_attention.py
...model_executor/layers/attention/encoder_only_attention.py
+0
-3
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+12
-8
vllm/model_executor/layers/attention/static_sink_attention.py
.../model_executor/layers/attention/static_sink_attention.py
+3
-8
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+3
-4
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+0
-3
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+7
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+29
-117
vllm/platforms/interface.py
vllm/platforms/interface.py
+50
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+7
-1
No files found.
tests/kernels/attention/test_attention_selector.py
View file @
77a73458
...
...
@@ -6,7 +6,12 @@ from unittest.mock import patch
import
pytest
import
torch
from
vllm.config
import
AttentionConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
VllmConfig
,
set_current_vllm_config
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
...
...
@@ -84,12 +89,15 @@ def test_backend_selection(
"""Test attention backend selection with valid device-backend pairs."""
# Create AttentionConfig with the specified backend
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
[
name
])
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
cache_config
=
CacheConfig
(
block_size
=
block_size
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
,
cache_config
=
cache_config
)
with
set_current_vllm_config
(
vllm_config
):
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
elif
device
==
"hip"
:
...
...
@@ -104,20 +112,16 @@ def test_backend_selection(
if
name
==
"TRITON_MLA"
and
block_size
==
1
:
# TRITON_MLA doesn't support block_size == 1
with
pytest
.
raises
(
ValueError
):
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
get_attn_backend
(
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
else
:
# Valid backend-block_size combination
backend
=
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
name
assert
backend
.
get_name
()
==
expected
else
:
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"ROCM_ATTN"
assert
backend
.
get_name
()
==
expected
...
...
@@ -141,7 +145,7 @@ def test_backend_selection(
if
capability
[
0
]
!=
10
:
pytest
.
skip
(
"CUTLASS MLA is not supported on this platform"
)
backend
=
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"CUTLASS_MLA"
assert
backend
.
get_name
()
==
expected
...
...
@@ -156,7 +160,7 @@ def test_backend_selection(
"FlashInfer MLA only supports block_size 32 or 64"
)
backend
=
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"FLASHINFER_MLA"
assert
backend
.
get_name
()
==
expected
...
...
@@ -175,7 +179,6 @@ def test_backend_selection(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
,
)
expected
=
name
...
...
@@ -190,27 +193,23 @@ def test_backend_selection(
"FlashAttention MLA not supported on this platform"
)
backend
=
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"FLASH_ATTN_MLA"
assert
backend
.
get_name
()
==
expected
else
:
# TRITON_MLA or other fallback
backend
=
get_attn_backend
(
576
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
576
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"TRITON_MLA"
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASHINFER"
:
backend
=
get_attn_backend
(
64
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
backend
=
get_attn_backend
(
64
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"FLASHINFER"
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASH_ATTN"
:
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
use_mla
=
use_mla
)
expected
=
"FLASH_ATTN"
assert
backend
.
get_name
()
==
expected
...
...
@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str):
with
set_current_vllm_config
(
vllm_config
):
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
elif
device
==
"cuda"
:
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
...
...
@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
)
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
FLASH_ATTN
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
cache_config
=
CacheConfig
(
block_size
=
16
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
,
cache_config
=
cache_config
)
with
set_current_vllm_config
(
vllm_config
):
# Unsupported CUDA arch
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
_
=
None
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
# Reset the monkeypatch for subsequent tests
monkeypatch
.
undo
()
# Unsupported data type
backend
=
get_attn_backend
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float8_e4m3fn
,
None
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
# Unsupported kv cache data type
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"fp8"
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"fp8"
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
# Unsupported block size
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
8
)
vllm_config
.
cache_config
.
block_size
=
8
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
# flash-attn is not installed
import
sys
vllm_config
.
cache_config
.
block_size
=
16
original_module
=
sys
.
modules
.
get
(
"vllm_flash_attn"
)
monkeypatch
.
setitem
(
sys
.
modules
,
"vllm_flash_attn"
,
None
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
# Restore the original module if it existed
...
...
@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch
.
delitem
(
sys
.
modules
,
"vllm_flash_attn"
,
raising
=
False
)
# Unsupported head size
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
,
16
)
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
)
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
...
...
@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config
(
vllm_config_auto
),
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()),
):
backend_auto
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
backend_auto
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
_cached_get_attn_backend
.
cache_clear
()
...
...
@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config
(
vllm_config_none
),
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()),
):
backend_none
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
backend_none
=
get_attn_backend
(
16
,
torch
.
float16
,
None
)
# Both should select the same backend
assert
backend_auto
.
get_name
()
==
backend_none
.
get_name
()
...
...
@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection(
backend
=
AttentionBackendEnum
[
backend_name
],
flash_attn_version
=
flash_attn_version
,
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
cache_config
=
CacheConfig
(
block_size
=
64
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
,
cache_config
=
cache_config
)
with
(
set_current_vllm_config
(
vllm_config
),
...
...
@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection(
head_size
=
128
,
dtype
=
torch
.
float16
,
kv_cache_dtype
=
"fp8"
,
block_size
=
64
,
use_per_head_quant_scales
=
True
,
)
assert
backend
.
get_name
()
==
backend_name
...
...
@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection(
head_size
=
128
,
dtype
=
torch
.
float16
,
kv_cache_dtype
=
"fp8"
,
block_size
=
64
,
use_per_head_quant_scales
=
True
,
)
assert
backend_name
in
str
(
exc_info
.
value
)
tests/models/multimodal/processing/test_tensor_schema.py
View file @
77a73458
...
...
@@ -13,6 +13,7 @@ import torch.nn as nn
from
PIL
import
Image
from
vllm.config
import
ModelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config.cache
import
CacheConfig
from
vllm.config.multimodal
import
(
AudioDummyOptions
,
BaseDummyOptions
,
...
...
@@ -131,7 +132,9 @@ def initialize_dummy_model(
):
temp_file
=
tempfile
.
mkstemp
()[
1
]
current_device
=
torch
.
get_default_device
()
vllm_config
=
VllmConfig
(
model_config
=
model_config
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(
block_size
=
16
)
)
with
set_current_vllm_config
(
vllm_config
=
vllm_config
):
init_distributed_environment
(
world_size
=
1
,
...
...
tests/v1/spec_decode/test_eagle.py
View file @
77a73458
...
...
@@ -80,7 +80,7 @@ def _create_proposer(
device
=
current_platform
.
device_type
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
cache_config
=
CacheConfig
(
block_size
=
16
),
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
device
),
parallel_config
=
ParallelConfig
(),
...
...
vllm/config/cache.py
View file @
77a73458
...
...
@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
field
from
typing
import
Literal
from
typing
import
ClassVar
,
Literal
from
pydantic
import
Field
,
SkipValidation
,
field_validator
from
pydantic
import
Field
,
SkipValidation
,
field_validator
,
model_validator
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
BlockSize
=
Literal
[
1
,
8
,
16
,
32
,
64
,
128
,
256
]
CacheDType
=
Literal
[
"auto"
,
"bfloat16"
,
...
...
@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class
CacheConfig
:
"""Configuration for the KV cache."""
block_size
:
SkipValidation
[
BlockSize
]
=
None
# type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens.
DEFAULT_BLOCK_SIZE
:
ClassVar
[
int
]
=
16
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current
platform."""
block_size
:
SkipValidation
[
int
]
=
None
# type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens.
Accepts None (meaning "use default"). After construction, always int."""
user_specified_block_size
:
bool
=
field
(
default
=
False
,
init
=
False
)
"""Whether block_size was explicitly provided. Derived automatically."""
gpu_memory_utilization
:
float
=
Field
(
default
=
0.9
,
gt
=
0
,
le
=
1
)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
...
...
@@ -169,6 +169,8 @@ class CacheConfig:
"prefix_caching_hash_algo"
,
"cpu_kvcache_space_bytes"
,
"mamba_page_size_padded"
,
"user_specified_block_size"
,
"_block_size_resolved"
,
# Post-init/derived counters
"num_gpu_blocks"
,
"num_cpu_blocks"
,
...
...
@@ -186,6 +188,22 @@ class CacheConfig:
# metrics info
return
{
key
:
str
(
value
)
for
key
,
value
in
self
.
__dict__
.
items
()}
_block_size_resolved
:
bool
=
field
(
default
=
False
,
init
=
False
)
"""Guard against pydantic re-running _apply_block_size_default."""
@
model_validator
(
mode
=
"after"
)
def
_apply_block_size_default
(
self
)
->
"CacheConfig"
:
# Pydantic re-runs validators when CacheConfig is nested inside
# another pydantic model (e.g. VllmConfig). Guard against that.
if
self
.
_block_size_resolved
:
return
self
object
.
__setattr__
(
self
,
"_block_size_resolved"
,
True
)
if
self
.
block_size
is
None
:
object
.
__setattr__
(
self
,
"block_size"
,
self
.
DEFAULT_BLOCK_SIZE
)
else
:
object
.
__setattr__
(
self
,
"user_specified_block_size"
,
True
)
return
self
@
field_validator
(
"cache_dtype"
,
mode
=
"after"
)
@
classmethod
def
_validate_cache_dtype
(
cls
,
cache_dtype
:
CacheDType
)
->
CacheDType
:
...
...
vllm/config/vllm.py
View file @
77a73458
...
...
@@ -1026,32 +1026,6 @@ class VllmConfig:
)
current_platform
.
check_and_update_config
(
self
)
# If DCP, ensure the block size is right.
if
self
.
parallel_config
.
decode_context_parallel_size
>
1
:
if
self
.
parallel_config
.
dcp_kv_cache_interleave_size
>
1
and
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
!=
self
.
parallel_config
.
dcp_kv_cache_interleave_size
):
self
.
parallel_config
.
cp_kv_cache_interleave_size
=
(
self
.
parallel_config
.
dcp_kv_cache_interleave_size
)
logger
.
warning_once
(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
<=
self
.
cache_config
.
block_size
and
self
.
cache_config
.
block_size
%
self
.
parallel_config
.
cp_kv_cache_interleave_size
==
0
),
(
f
"Block_size(
{
self
.
cache_config
.
block_size
}
) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f
"(
{
self
.
parallel_config
.
cp_kv_cache_interleave_size
}
)."
)
# Do this after all the updates to compilation_config.mode
effective_dp_size
=
(
self
.
parallel_config
.
data_parallel_size
...
...
@@ -1219,26 +1193,6 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above.
self
.
scheduler_config
.
disable_hybrid_kv_cache_manager
=
False
if
self
.
cache_config
.
mamba_cache_mode
==
"align"
:
assert
(
self
.
cache_config
.
block_size
<=
self
.
scheduler_config
.
max_num_batched_tokens
),
(
"In Mamba cache align mode, block_size "
f
"(
{
self
.
cache_config
.
block_size
}
) must be <= "
"max_num_batched_tokens "
f
"(
{
self
.
scheduler_config
.
max_num_batched_tokens
}
)."
)
if
self
.
scheduler_config
.
long_prefill_token_threshold
>
0
:
assert
(
self
.
scheduler_config
.
long_prefill_token_threshold
>=
self
.
cache_config
.
block_size
)
assert
not
self
.
scheduler_config
.
disable_chunked_mm_input
,
(
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if
self
.
compilation_config
.
debug_dump_path
:
self
.
compilation_config
.
debug_dump_path
=
(
self
.
compilation_config
.
debug_dump_path
.
absolute
().
expanduser
()
...
...
@@ -1673,6 +1627,53 @@ class VllmConfig:
f
"compilation_config=
{
self
.
compilation_config
!
r
}
"
)
def
validate_block_size
(
self
)
->
None
:
"""Validate block_size against DCP and mamba constraints.
Called after Platform.update_block_size_for_backend() has
finalised block_size.
"""
block_size
=
self
.
cache_config
.
block_size
# DCP interleave-size compatibility
if
self
.
parallel_config
.
decode_context_parallel_size
>
1
:
if
self
.
parallel_config
.
dcp_kv_cache_interleave_size
>
1
and
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
!=
self
.
parallel_config
.
dcp_kv_cache_interleave_size
):
self
.
parallel_config
.
cp_kv_cache_interleave_size
=
(
self
.
parallel_config
.
dcp_kv_cache_interleave_size
)
logger
.
warning_once
(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
<=
block_size
and
block_size
%
self
.
parallel_config
.
cp_kv_cache_interleave_size
==
0
),
(
f
"Block_size(
{
block_size
}
) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f
"(
{
self
.
parallel_config
.
cp_kv_cache_interleave_size
}
)."
)
# Mamba cache align-mode constraints
if
self
.
cache_config
.
mamba_cache_mode
==
"align"
:
assert
block_size
<=
self
.
scheduler_config
.
max_num_batched_tokens
,
(
"In Mamba cache align mode, block_size "
f
"(
{
block_size
}
) must be <= "
"max_num_batched_tokens "
f
"(
{
self
.
scheduler_config
.
max_num_batched_tokens
}
)."
)
if
self
.
scheduler_config
.
long_prefill_token_threshold
>
0
:
assert
self
.
scheduler_config
.
long_prefill_token_threshold
>=
block_size
assert
not
self
.
scheduler_config
.
disable_chunked_mm_input
,
(
"Chunked MM input is required because we need the flexibility "
"to schedule a multiple of block_size tokens even if they are "
"in the middle of a mm input"
)
@
model_validator
(
mode
=
"after"
)
def
validate_mamba_block_size
(
self
)
->
"VllmConfig"
:
if
self
.
model_config
is
None
:
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
77a73458
...
...
@@ -500,7 +500,6 @@ def get_current_attn_backend(vllm_config: VllmConfig):
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
kv_cache_dtype
=
vllm_config
.
cache_config
.
cache_dtype
,
block_size
=
vllm_config
.
cache_config
.
block_size
,
use_mla
=
vllm_config
.
model_config
.
use_mla
,
)
return
backend
vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py
View file @
77a73458
...
...
@@ -726,7 +726,6 @@ class MoRIIOConnectorWorker:
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
use_mla
=
self
.
use_mla
,
)
...
...
vllm/engine/arg_utils.py
View file @
77a73458
...
...
@@ -62,7 +62,6 @@ from vllm.config import (
get_attr_docs
,
)
from
vllm.config.cache
import
(
BlockSize
,
CacheDType
,
KVOffloadingBackend
,
MambaCacheMode
,
...
...
@@ -440,7 +439,7 @@ class EngineArgs:
max_parallel_loading_workers
:
int
|
None
=
(
ParallelConfig
.
max_parallel_loading_workers
)
block_size
:
BlockSize
=
CacheConfig
.
block_siz
e
block_size
:
int
|
None
=
Non
e
enable_prefix_caching
:
bool
|
None
=
None
prefix_caching_hash_algo
:
PrefixCachingHashAlgo
=
(
CacheConfig
.
prefix_caching_hash_algo
...
...
@@ -1521,7 +1520,7 @@ class EngineArgs:
)
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
# type: ignore[arg-type]
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
kv_cache_memory_bytes
=
self
.
kv_cache_memory_bytes
,
cache_dtype
=
resolved_cache_dtype
,
# type: ignore[arg-type]
...
...
vllm/model_executor/layers/attention/attention.py
View file @
77a73458
...
...
@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase):
vllm_config
=
get_current_vllm_config
()
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
...
...
@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase):
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_mla
=
False
,
has_sink
=
self
.
has_sink
,
use_mm_prefix
=
self
.
use_mm_prefix
,
...
...
vllm/model_executor/layers/attention/chunked_local_attention.py
View file @
77a73458
...
...
@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import (
def
create_chunked_local_attention_backend
(
underlying_attn_backend
:
AttentionBackend
,
attention_chunk_size
:
int
,
block_size
:
int
,
)
->
type
[
AttentionBackend
]:
prefix
=
f
"ChunkedLocalAttention_
{
attention_chunk_size
}
_
{
block_size
}
_
"
prefix
=
f
"ChunkedLocalAttention_
{
attention_chunk_size
}
_"
underlying_builder
=
underlying_attn_backend
.
get_builder_cls
()
assert
issubclass
(
underlying_builder
,
AttentionMetadataBuilder
)
...
...
@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend(
fast_build
:
bool
=
False
,
):
cm
,
make_virtual_batches_block_table
=
make_local_attention_virtual_batches
(
attention_chunk_size
,
common_attn_metadata
,
block_size
attention_chunk_size
,
common_attn_metadata
,
self
.
kv_cache_spec
.
block_size
,
)
metadata
=
super
().
build
(
common_prefix_len
,
cm
,
fast_build
)
metadata
.
make_virtual_batches_block_table
=
make_virtual_batches_block_table
...
...
@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention):
dtype
=
torch
.
get_default_dtype
()
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
)
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
)
attn_backend
=
create_chunked_local_attention_backend
(
underlying_attn_backend
,
attention_chunk_size
,
block_size
underlying_attn_backend
,
attention_chunk_size
)
super
().
__init__
(
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
77a73458
...
...
@@ -188,10 +188,8 @@ class CrossAttention(Attention):
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
if
attn_type
is
not
None
:
assert
attn_type
==
AttentionType
.
ENCODER_DECODER
,
(
...
...
@@ -202,7 +200,6 @@ class CrossAttention(Attention):
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
attn_type
=
AttentionType
.
ENCODER_DECODER
,
)
attn_backend
=
create_cross_attention_backend
(
underlying_attn_backend
)
...
...
vllm/model_executor/layers/attention/encoder_only_attention.py
View file @
77a73458
...
...
@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention):
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
attn_type
=
AttentionType
.
ENCODER_ONLY
,
)
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
77a73458
...
...
@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
calculate_kv_scales
=
cache_config
.
calculate_kv_scales
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
calculate_kv_scales
=
False
self
.
quant_config
=
quant_config
...
...
@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self
.
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_mla
=
True
,
use_sparse
=
use_sparse
,
num_heads
=
self
.
num_heads
,
...
...
@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
# Attributes for forward_impl method
self
.
chunked_prefill_workspace_size
=
(
MLACommonMetadataBuilder
.
determine_chunked_prefill_workspace_size
(
get_current_vllm_config
()
)
)
self
.
_vllm_config
=
get_current_vllm_config
()
self
.
_chunked_prefill_workspace_size
:
int
|
None
=
None
self
.
_decode_concat_quant_fp8_op
=
_DecodeConcatQuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
compile_native
=
True
,
)
@
property
def
chunked_prefill_workspace_size
(
self
)
->
int
:
if
self
.
_chunked_prefill_workspace_size
is
None
:
self
.
_chunked_prefill_workspace_size
=
(
MLACommonMetadataBuilder
.
determine_chunked_prefill_workspace_size
(
self
.
_vllm_config
)
)
return
self
.
_chunked_prefill_workspace_size
def
forward
(
self
,
q
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/attention/static_sink_attention.py
View file @
77a73458
...
...
@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp):
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
if
attn_backend
is
not
None
:
underlying_attn_backend
=
attn_backend
else
:
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
)
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
)
attn_backend
=
create_static_sink_attention_backend
(
underlying_attn_backend
,
# type: ignore[arg-type]
sink_len
=
sink_len
,
...
...
@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp):
CustomOp
.
__init__
(
self
)
self
.
sink_len
=
sink_len
self
.
block_size
=
block_size
self
.
sink_populated
=
False
self
.
sink_key
=
None
self
.
sink_value
=
None
...
...
@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp):
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
:
# Block size may get updated after model loading, refresh it
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
# Should not be called for enc-dec or encoder-only attention.
assert
self
.
attn_type
==
AttentionType
.
DECODER
return
SinkFullAttentionSpec
(
block_size
=
block_size
,
block_size
=
self
.
block_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_size
=
self
.
head_size
,
head_size_v
=
self
.
head_size_v
,
...
...
vllm/model_executor/models/config.py
View file @
77a73458
...
...
@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size
,
kernel_block_alignment_size
*
attn_page_size_1_token
)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if
cache_config
.
block_size
is
None
or
cache_config
.
block_size
<
attn_block_size
:
# override attention block size if it is too small,
# even if the user has explicitly set it
if
cache_config
.
block_size
<
attn_block_size
:
cache_config
.
block_size
=
attn_block_size
logger
.
info
(
"Setting attention block size to %d tokens "
...
...
vllm/model_executor/models/whisper_causal.py
View file @
77a73458
...
...
@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention):
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
underlying_attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
attn_type
=
attn_type
,
)
attn_backend
=
create_whisper_attention_backend_with_block_pooling
(
...
...
vllm/platforms/cpu.py
View file @
77a73458
...
...
@@ -185,7 +185,7 @@ class CpuPlatform(Platform):
cache_config
=
vllm_config
.
cache_config
if
cache_config
.
block_size
is
None
:
if
not
cache_config
.
user_specified_
block_size
:
cache_config
.
block_size
=
128
if
cache_config
.
block_size
%
32
!=
0
:
...
...
@@ -361,6 +361,12 @@ class CpuPlatform(Platform):
vllm_config
.
scheduler_config
.
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
)
@
classmethod
def
update_block_size_for_backend
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
# TODO: CPU still sets block_size in check_and_update_config.
# Move that logic here so block_size is chosen by the backend.
pass
@
classmethod
def
get_allowed_cpu_core_node_list
(
cls
)
->
tuple
[
list
[
int
],
list
[
LogicalCPUInfo
]]:
assert
platform
.
system
()
==
"Linux"
...
...
vllm/platforms/cuda.py
View file @
77a73458
...
...
@@ -166,122 +166,12 @@ class CudaPlatformBase(Platform):
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
None
:
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
parallel_config
=
vllm_config
.
parallel_config
model_config
=
vllm_config
.
model_config
if
parallel_config
.
worker_cls
==
"auto"
:
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
cache_config
=
vllm_config
.
cache_config
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if
(
model_config
is
not
None
and
model_config
.
use_mla
and
cache_config
.
block_size
is
not
None
):
use_sparse
=
hasattr
(
vllm_config
.
model_config
.
hf_config
,
"index_topk"
)
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla
=
False
use_cutlass_mla
=
False
use_flashinfer_mla
=
False
use_flashmla_sparse
=
False
use_flashinfer_mla_sparse
=
False
from
vllm.v1.attention.ops.flashmla
import
is_flashmla_dense_supported
if
vllm_config
.
attention_config
.
backend
is
None
:
# Default case
hf_text_config
=
model_config
.
hf_text_config
qk_nope_head_dim
=
getattr
(
hf_text_config
,
"qk_nope_head_dim"
,
1
)
if
(
cls
.
is_device_capability_family
(
100
)
and
not
use_sparse
and
qk_nope_head_dim
==
128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla
=
True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config
.
attention_config
.
backend
=
(
AttentionBackendEnum
.
FLASHINFER_MLA
)
elif
cls
.
is_device_capability_family
(
100
)
and
not
use_sparse
:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla
=
True
elif
is_flashmla_dense_supported
()[
0
]:
# Non-Blackwell with FlashMLA support
use_flashmla
=
True
else
:
# Fallback: will use Triton MLA or other compatible backend
pass
else
:
# Forced case
backend
=
vllm_config
.
attention_config
.
backend
use_flashmla
=
backend
==
AttentionBackendEnum
.
FLASHMLA
use_cutlass_mla
=
backend
==
AttentionBackendEnum
.
CUTLASS_MLA
use_flashinfer_mla
=
backend
==
AttentionBackendEnum
.
FLASHINFER_MLA
use_flashmla_sparse
=
backend
==
AttentionBackendEnum
.
FLASHMLA_SPARSE
use_flashinfer_mla_sparse
=
(
backend
==
AttentionBackendEnum
.
FLASHINFER_MLA_SPARSE
)
if
(
use_flashmla
and
is_flashmla_dense_supported
()[
0
]
and
cache_config
.
block_size
%
64
!=
0
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashMLA backend."
)
if
use_cutlass_mla
and
cache_config
.
block_size
%
128
!=
0
:
cache_config
.
block_size
=
128
logger
.
info
(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if
(
use_flashinfer_mla
and
cache_config
.
block_size
!=
32
and
cache_config
.
block_size
%
64
!=
0
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if
use_sparse
:
if
not
(
use_flashmla_sparse
or
use_flashinfer_mla_sparse
):
use_flashmla_sparse
=
True
if
use_flashmla_sparse
and
cache_config
.
block_size
!=
64
:
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif
use_flashinfer_mla_sparse
and
cache_config
.
block_size
not
in
(
32
,
64
,
):
cache_config
.
block_size
=
64
logger
.
info
(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
scheduler_config
=
vllm_config
.
scheduler_config
# Note: model_config may be None during testing
if
(
...
...
@@ -312,10 +202,10 @@ class CudaPlatformBase(Platform):
num_heads
:
int
|
None
=
None
,
)
->
tuple
[
list
[
tuple
[
"AttentionBackendEnum"
,
int
]],
dict
[
"AttentionBackendEnum"
,
list
[
str
]],
dict
[
"AttentionBackendEnum"
,
tuple
[
int
,
list
[
str
]]
]
,
]:
valid_backends_priorities
=
[]
invalid_reasons
=
{}
invalid_reasons
:
dict
[
AttentionBackendEnum
,
tuple
[
int
,
list
[
str
]]]
=
{}
backend_priorities
=
_get_backend_priorities
(
attn_selector_config
.
use_mla
,
...
...
@@ -332,7 +222,7 @@ class CudaPlatformBase(Platform):
except
ImportError
:
invalid_reasons_i
=
[
"ImportError"
]
if
invalid_reasons_i
:
invalid_reasons
[
backend
]
=
invalid_reasons_i
invalid_reasons
[
backend
]
=
(
priority
,
invalid_reasons_i
)
else
:
valid_backends_priorities
.
append
((
backend
,
priority
))
...
...
@@ -341,14 +231,13 @@ class CudaPlatformBase(Platform):
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
"AttentionBackendEnum"
,
selected_backend
:
"AttentionBackendEnum
| None
"
,
attn_selector_config
:
"AttentionSelectorConfig"
,
num_heads
:
int
|
None
=
None
,
)
->
str
:
device_capability
=
cls
.
get_device_capability
()
assert
device_capability
is
not
None
attn_selector_config
=
attn_selector_config
.
_replace
(
block_size
=
None
)
# First try checking just the selected backend, if there is one.
if
selected_backend
is
not
None
:
try
:
...
...
@@ -370,7 +259,7 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities
,
invalid_reasons
=
cls
.
get_valid_backends
(
valid_backends_priorities
,
all_
invalid_reasons
=
cls
.
get_valid_backends
(
device_capability
=
device_capability
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
...
...
@@ -379,7 +268,7 @@ class CudaPlatformBase(Platform):
"{"
+
", "
.
join
(
f
"
{
backend
.
name
}
: [
{
', '
.
join
(
reasons
)
}
]"
for
backend
,
reasons
in
invalid_reasons
.
items
()
for
backend
,
(
_
,
reasons
)
in
all_
invalid_reasons
.
items
()
)
+
"}"
)
...
...
@@ -402,6 +291,29 @@ class CudaPlatformBase(Platform):
)
selected_index
=
sorted_indices
[
0
]
selected_backend
=
valid_backends_priorities
[
selected_index
][
0
]
selected_priority
=
valid_backends_priorities
[
selected_index
][
1
]
# If the user specified --block-size (but not --attention-backend),
# check whether that constraint precluded any higher-priority backends.
if
attn_selector_config
.
block_size
is
not
None
:
excluded
=
[
backend
for
backend
,
(
priority
,
reasons
)
in
all_invalid_reasons
.
items
()
if
priority
<
selected_priority
and
reasons
==
[
"block_size not supported"
]
]
if
excluded
:
names
=
", "
.
join
(
b
.
name
for
b
in
excluded
)
logger
.
warning
(
"--block-size %d precluded higher-priority backend(s) "
"%s. Using %s instead, which may result in reduced "
"performance. Consider removing --block-size to "
"auto-select the optimal block size."
,
attn_selector_config
.
block_size
,
names
,
selected_backend
.
name
,
)
logger
.
info_once
(
"Using %s attention backend out of potential backends: %s."
,
selected_backend
.
name
,
...
...
vllm/platforms/interface.py
View file @
77a73458
...
...
@@ -420,6 +420,56 @@ class Platform:
"""
pass
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """
    Ensure block_size is compatible with the attention backend.

    ``vllm_config.cache_config.block_size`` is updated in place:

    * If the user specified a block size, nothing is changed.
    * If there is no model config, the model is hybrid, or no attention
      layers are registered, ``CacheConfig.DEFAULT_BLOCK_SIZE`` is used.
    * Otherwise the backend of the first attention layer is asked for its
      preferred block size, which is then stored.

    Args:
        vllm_config: The configuration whose cache config is adjusted.
    """
    from vllm.config.cache import CacheConfig

    cache_cfg = vllm_config.cache_config
    # An explicit --block-size from the user always wins.
    if cache_cfg.user_specified_block_size:
        return

    fallback_size = CacheConfig.DEFAULT_BLOCK_SIZE

    # model_config may be None during testing. Hybrid models are not
    # handled here: their block_size is managed by
    # HybridAttentionMambaModelConfig.
    model_cfg = vllm_config.model_config
    if model_cfg is None or model_cfg.is_hybrid:
        cache_cfg.block_size = fallback_size
        return

    # Imported lazily, only once a backend actually has to be consulted.
    from vllm.config.vllm import (
        get_layers_from_vllm_config,
        set_current_vllm_config,
    )
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )

    layers_by_name = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    if not layers_by_name:
        cache_cfg.block_size = fallback_size
        return

    # Only the first registered attention layer's backend is consulted.
    backend = next(iter(layers_by_name.values())).get_attn_backend()
    with set_current_vllm_config(vllm_config):
        chosen_size = backend.get_preferred_block_size(fallback_size)

    # Only log when the backend deviates from the default.
    if chosen_size != fallback_size:
        logger.info(
            "Setting kv cache block size to %d for %s backend.",
            chosen_size,
            backend.get_name(),
        )
    cache_cfg.block_size = chosen_size
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
"""
...
...
vllm/platforms/rocm.py
View file @
77a73458
...
...
@@ -687,7 +687,7 @@ class RocmPlatform(Platform):
)
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
if
cache_config
and
cache_config
.
block_size
is
None
:
if
cache_config
and
not
cache_config
.
user_specified_
block_size
:
if
(
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
and
envs
.
VLLM_ROCM_USE_AITER
# NOTE: This block has been deprecated
...
...
@@ -707,6 +707,12 @@ class RocmPlatform(Platform):
if
parallel_config
.
worker_cls
==
"auto"
:
parallel_config
.
worker_cls
=
"vllm.v1.worker.gpu_worker.Worker"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """
    No-op override: leave ``vllm_config`` untouched.

    Args:
        vllm_config: The configuration object; it is not read or modified
            here.
    """
    # TODO: ROCm still sets block_size in check_and_update_config.
    # Move that logic here so block_size is chosen by the backend.
    return None
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment