Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2aaa4238
Unverified
Commit
2aaa4238
authored
Oct 02, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 02, 2025
Browse files
[Attention] Move Backend enum into registry (#25893)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
ad2d7880
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
39 additions
and
41 deletions
+39
-41
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-1
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+1
-1
vllm/model_executor/models/siglip2navit.py
vllm/model_executor/models/siglip2navit.py
+1
-1
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+2
-1
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+0
-1
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+5
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+7
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+5
-26
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+7
-2
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+5
-2
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+5
-2
No files found.
vllm/model_executor/models/qwen2_vl.py
View file @
2aaa4238
...
...
@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from
transformers.models.qwen2_vl.video_processing_qwen2_vl
import
(
Qwen2VLVideoProcessor
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
...
...
@@ -65,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
2aaa4238
...
...
@@ -43,6 +43,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize
as
video_smart_resize
)
from
transformers.video_utils
import
VideoMetadata
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
...
...
@@ -66,7 +67,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
...
...
vllm/model_executor/models/siglip2navit.py
View file @
2aaa4238
...
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
from
transformers
import
Siglip2VisionConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
...
...
@@ -22,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.platforms
import
_Backend
from
.vision
import
get_vit_attn_backend
...
...
vllm/model_executor/models/vision.py
View file @
2aaa4238
...
...
@@ -10,11 +10,12 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
import
torch
from
transformers
import
PretrainedConfig
from
vllm.attention.backends.registry
import
_Backend
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
vllm/platforms/__init__.py
View file @
2aaa4238
...
...
@@ -9,7 +9,6 @@ from vllm import envs
from
vllm.plugins
import
load_plugins_by_group
from
vllm.utils
import
resolve_obj_by_qualname
,
supports_xccl
from
.interface
import
_Backend
# noqa: F401
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
logger
=
logging
.
getLogger
(
__name__
)
...
...
vllm/platforms/cpu.py
View file @
2aaa4238
...
...
@@ -15,13 +15,15 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
VllmConfig
else
:
_Backend
=
None
VllmConfig
=
None
...
...
@@ -90,10 +92,11 @@ class CpuPlatform(Platform):
return
"cpu"
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
def
get_attn_backend_cls
(
cls
,
selected_backend
:
"
_Backend
"
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
,
use_sparse
:
bool
)
->
str
:
from
vllm.attention.backends.registry
import
_Backend
if
selected_backend
and
selected_backend
!=
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
if
use_mla
:
...
...
vllm/platforms/cuda.py
View file @
2aaa4238
...
...
@@ -20,10 +20,13 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
,
import_pynvml
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
_Backend
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -202,7 +205,8 @@ class CudaPlatformBase(Platform):
@
classmethod
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
dtype
:
torch
.
dtype
)
->
"_Backend"
:
from
vllm.attention.backends.registry
import
_Backend
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
...
...
@@ -230,6 +234,7 @@ class CudaPlatformBase(Platform):
def
get_attn_backend_cls
(
cls
,
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
,
has_sink
,
use_sparse
)
->
str
:
from
vllm.attention.backends.registry
import
_Backend
if
use_mla
:
if
not
use_v1
:
raise
RuntimeError
(
...
...
vllm/platforms/interface.py
View file @
2aaa4238
...
...
@@ -17,12 +17,14 @@ from vllm.inputs import ProcessorInputs, PromptType
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
else
:
_Backend
=
None
ModelConfig
=
None
VllmConfig
=
None
LoRARequest
=
None
...
...
@@ -38,30 +40,6 @@ def in_wsl() -> bool:
return
"microsoft"
in
" "
.
join
(
uname
()).
lower
()
class
_Backend
(
enum
.
Enum
):
FLASH_ATTN
=
enum
.
auto
()
TRITON_ATTN
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
ROCM_AITER_MLA
=
enum
.
auto
()
# Supported by V1
ROCM_AITER_FA
=
enum
.
auto
()
# used for ViT attn backend
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER_MLA
=
enum
.
auto
()
TRITON_MLA
=
enum
.
auto
()
# Supported by V1
CUTLASS_MLA
=
enum
.
auto
()
FLASHMLA
=
enum
.
auto
()
# Supported by V1
FLASH_ATTN_MLA
=
enum
.
auto
()
# Supported by V1
PALLAS
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
DUAL_CHUNK_FLASH_ATTN
=
enum
.
auto
()
DIFFERENTIAL_FLASH_ATTN
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
ROCM_ATTN
=
enum
.
auto
()
class
PlatformEnum
(
enum
.
Enum
):
CUDA
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
...
...
@@ -187,11 +165,12 @@ class Platform:
@
classmethod
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
dtype
:
torch
.
dtype
)
->
"_Backend"
:
from
vllm.attention.backends.registry
import
_Backend
return
_Backend
.
TORCH_SDPA
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
def
get_attn_backend_cls
(
cls
,
selected_backend
:
"
_Backend
"
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
,
use_sparse
:
bool
)
->
str
:
...
...
vllm/platforms/rocm.py
View file @
2aaa4238
...
...
@@ -14,10 +14,13 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
_Backend
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -182,7 +185,8 @@ class RocmPlatform(Platform):
@
classmethod
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
dtype
:
torch
.
dtype
)
->
"_Backend"
:
from
vllm.attention.backends.registry
import
_Backend
if
(
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
and
on_gfx9
()):
# Note: AITER FA is only supported for Qwen-VL models.
...
...
@@ -196,6 +200,7 @@ class RocmPlatform(Platform):
def
get_attn_backend_cls
(
cls
,
selected_backend
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
,
has_sink
,
use_sparse
)
->
str
:
from
vllm.attention.backends.registry
import
_Backend
if
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on ROCm."
)
...
...
vllm/platforms/tpu.py
View file @
2aaa4238
...
...
@@ -11,9 +11,10 @@ from vllm.logger import init_logger
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
BlockSize
,
ModelConfig
,
VllmConfig
from
vllm.pooling_params
import
PoolingParams
else
:
...
...
@@ -21,6 +22,7 @@ else:
ModelConfig
=
None
VllmConfig
=
None
PoolingParams
=
None
_Backend
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -46,10 +48,11 @@ class TpuPlatform(Platform):
]
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
def
get_attn_backend_cls
(
cls
,
selected_backend
:
"
_Backend
"
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
,
use_sparse
)
->
str
:
from
vllm.attention.backends.registry
import
_Backend
if
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on TPU."
)
...
...
vllm/platforms/xpu.py
View file @
2aaa4238
...
...
@@ -10,13 +10,15 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
None
VllmConfig
=
None
_Backend
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -33,10 +35,11 @@ class XPUPlatform(Platform):
device_control_env_var
:
str
=
"ZE_AFFINITY_MASK"
@
classmethod
def
get_attn_backend_cls
(
cls
,
selected_backend
:
_Backend
,
head_size
:
int
,
def
get_attn_backend_cls
(
cls
,
selected_backend
:
"
_Backend
"
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
,
use_sparse
)
->
str
:
from
vllm.attention.backends.registry
import
_Backend
if
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on XPU."
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment