Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
76879cc1
Unverified
Commit
76879cc1
authored
Oct 08, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 08, 2025
Browse files
[Attention] Implement universal BACKEND_MAP (#25900)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
b25d7b56
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
119 additions
and
75 deletions
+119
-75
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+2
-2
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+1
-1
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+2
-2
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+3
-3
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+14
-38
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+5
-5
tests/v1/spec_decode/test_mtp.py
tests/v1/spec_decode/test_mtp.py
+2
-2
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+2
-2
vllm/attention/backends/registry.py
vllm/attention/backends/registry.py
+83
-2
vllm/attention/layer.py
vllm/attention/layer.py
+2
-2
vllm/attention/selector.py
vllm/attention/selector.py
+1
-14
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+2
-2
No files found.
tests/kernels/attention/test_attention_selector.py
View file @
76879cc1
...
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
...
@@ -34,7 +34,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS
=
{
DEVICE_REGULAR_ATTN_BACKENDS
=
{
"cuda"
:
[
"XFORMERS"
,
"FLASHINFER"
,
"FLASH_ATTN"
],
"cuda"
:
[
"XFORMERS"
,
"FLASHINFER"
,
"FLASH_ATTN"
],
"hip"
:
[
"ROCM_
FLASH
"
],
"hip"
:
[
"ROCM_
ATTN
"
],
"cpu"
:
[
"TORCH_SDPA"
],
"cpu"
:
[
"TORCH_SDPA"
],
}
}
...
@@ -122,7 +122,7 @@ def test_env(
...
@@ -122,7 +122,7 @@ def test_env(
backend
=
get_attn_backend
(
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
16
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
)
expected
=
"
TRITON
_ATTN"
expected
=
"
ROCM
_ATTN"
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
elif
device
==
"cuda"
:
elif
device
==
"cuda"
:
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
76879cc1
...
@@ -18,7 +18,7 @@ def clear_cache():
...
@@ -18,7 +18,7 @@ def clear_cache():
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_
FLASH
"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_
ATTN
"
)
# Set the current platform to ROCm using monkeypatch
# Set the current platform to ROCm using monkeypatch
monkeypatch
.
setattr
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
())
monkeypatch
.
setattr
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
())
...
...
tests/v1/attention/test_attention_backends.py
View file @
76879cc1
...
@@ -14,7 +14,7 @@ from tests.v1.attention.utils import (
...
@@ -14,7 +14,7 @@ from tests.v1.attention.utils import (
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
create_vllm_config
,
create_vllm_config
,
get_attention_backend
,
try_
get_attention_backend
,
)
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -214,7 +214,7 @@ def run_attention_backend(
...
@@ -214,7 +214,7 @@ def run_attention_backend(
actual_backend
=
_Backend
.
FLEX_ATTENTION
actual_backend
=
_Backend
.
FLEX_ATTENTION
use_direct_block_mask
=
False
use_direct_block_mask
=
False
builder_cls
,
impl_cls
=
get_attention_backend
(
actual_backend
)
builder_cls
,
impl_cls
=
try_
get_attention_backend
(
actual_backend
)
# Mock flashinfer's get_per_layer_parameters if needed
# Mock flashinfer's get_per_layer_parameters if needed
if
actual_backend
==
_Backend
.
FLASHINFER
:
if
actual_backend
==
_Backend
.
FLASHINFER
:
...
...
tests/v1/attention/test_mla_backends.py
View file @
76879cc1
...
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
...
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
create_vllm_config
,
create_vllm_config
,
get_attention_backend
,
try_
get_attention_backend
,
)
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
...
@@ -239,7 +239,7 @@ def run_attention_backend(
...
@@ -239,7 +239,7 @@ def run_attention_backend(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls
,
impl_cls
=
get_attention_backend
(
backend
)
builder_cls
,
impl_cls
=
try_
get_attention_backend
(
backend
)
# Build metadata
# Build metadata
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
builder
=
builder_cls
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
...
@@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# Determine if this is decode or prefill
# Determine if this is decode or prefill
is_decode
=
[]
is_decode
=
[]
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
builder_cls
,
_
=
get_attention_backend
(
backend
)
builder_cls
,
_
=
try_
get_attention_backend
(
backend
)
is_decode
.
append
(
q_len
<=
builder_cls
.
reorder_batch_threshold
)
is_decode
.
append
(
q_len
<=
builder_cls
.
reorder_batch_threshold
)
# Split q into nope and rope components
# Split q into nope and rope components
...
...
tests/v1/attention/utils.py
View file @
76879cc1
...
@@ -8,7 +8,8 @@ from typing import Optional, Union
...
@@ -8,7 +8,8 @@ from typing import Optional, Union
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.abstract
import
AttentionImpl
from
vllm.attention.backends.registry
import
_Backend
,
backend_to_class_str
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
CacheConfig
,
CompilationConfig
,
CompilationConfig
,
...
@@ -20,9 +21,11 @@ from vllm.config import (
...
@@ -20,9 +21,11 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
)
)
from
vllm.config.model
import
ModelDType
from
vllm.config.model
import
ModelDType
from
vllm.platforms
import
current_platform
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
...
@@ -117,44 +120,17 @@ def create_common_attn_metadata(
...
@@ -117,44 +120,17 @@ def create_common_attn_metadata(
)
)
def
get_attention_backend
(
backend_name
:
_Backend
):
def
try_get_attention_backend
(
"""Set up attention backend classes for testing.
backend
:
_Backend
,
)
->
tuple
[
type
[
AttentionMetadataBuilder
],
type
[
AttentionImpl
]]:
Args:
"""Try to get the attention backend class, skipping test if not found."""
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
backend_class_str
=
backend_to_class_str
(
backend
)
vllm_config: VllmConfig instance
Returns:
Tuple of (backend_builder_class, backend_impl_class)
"""
backend_map
=
{
_Backend
.
FLASH_ATTN
:
(
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if
current_platform
.
is_cuda
()
else
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
),
_Backend
.
FLASHINFER
:
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
,
_Backend
.
FLEX_ATTENTION
:
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
,
# noqa: E501
_Backend
.
TRITON_ATTN
:
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
,
# noqa: E501
_Backend
.
TREE_ATTN
:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
,
_Backend
.
XFORMERS
:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
,
# noqa: E501
_Backend
.
CUTLASS_MLA
:
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
,
# noqa: E501
_Backend
.
FLASHMLA
:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
_Backend
.
FLASH_ATTN_MLA
:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
,
# noqa: E501
_Backend
.
FLASHINFER_MLA
:
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
,
# noqa: E501
_Backend
.
TRITON_MLA
:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
# noqa: E501
}
if
backend_name
not
in
backend_map
:
raise
ValueError
(
f
"Unknown backend:
{
backend_name
}
"
)
backend_class_name
=
backend_map
[
backend_name
]
try
:
try
:
backend_class
=
resolve_obj_by_qualname
(
backend_class_
name
)
backend_class
=
resolve_obj_by_qualname
(
backend_class_
str
)
return
backend_class
.
get_builder_cls
(),
backend_class
.
get_impl_cls
()
return
backend_class
.
get_builder_cls
(),
backend_class
.
get_impl_cls
()
except
ImportError
as
e
:
except
ImportError
as
e
:
pytest
.
skip
(
f
"
{
backend_name
}
not available:
{
e
}
"
)
pytest
.
skip
(
f
"
{
backend_class_str
}
not available:
{
e
}
"
)
raise
AssertionError
(
"unreachable"
)
from
None
def
create_standard_kv_cache_spec
(
vllm_config
:
VllmConfig
)
->
FullAttentionSpec
:
def
create_standard_kv_cache_spec
(
vllm_config
:
VllmConfig
)
->
FullAttentionSpec
:
...
...
tests/v1/spec_decode/test_eagle.py
View file @
76879cc1
...
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
...
@@ -12,7 +12,7 @@ from tests.v1.attention.utils import (
BatchSpec
,
BatchSpec
,
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
get_attention_backend
,
try_
get_attention_backend
,
)
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
(
from
vllm.config
import
(
...
@@ -535,11 +535,11 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -535,11 +535,11 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
sampling_metadata
=
mock
.
MagicMock
()
sampling_metadata
=
mock
.
MagicMock
()
if
attn_backend
==
"FLASH_ATTN"
:
if
attn_backend
==
"FLASH_ATTN"
:
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
FLASH_ATTN
)
attn_metadata_builder_cls
,
_
=
try_
get_attention_backend
(
_Backend
.
FLASH_ATTN
)
elif
attn_backend
==
"TRITON_ATTN"
:
elif
attn_backend
==
"TRITON_ATTN"
:
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
TRITON_ATTN
)
attn_metadata_builder_cls
,
_
=
try_
get_attention_backend
(
_Backend
.
TRITON_ATTN
)
elif
attn_backend
==
"TREE_ATTN"
:
elif
attn_backend
==
"TREE_ATTN"
:
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
TREE_ATTN
)
attn_metadata_builder_cls
,
_
=
try_
get_attention_backend
(
_Backend
.
TREE_ATTN
)
else
:
else
:
raise
ValueError
(
f
"Unsupported attention backend:
{
attn_backend
}
"
)
raise
ValueError
(
f
"Unsupported attention backend:
{
attn_backend
}
"
)
...
@@ -674,7 +674,7 @@ def test_propose_tree(spec_token_tree):
...
@@ -674,7 +674,7 @@ def test_propose_tree(spec_token_tree):
proposer
.
attn_layer_names
=
[
"layer.0"
]
proposer
.
attn_layer_names
=
[
"layer.0"
]
# Get the tree attention metadata builder.
# Get the tree attention metadata builder.
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
TREE_ATTN
)
attn_metadata_builder_cls
,
_
=
try_
get_attention_backend
(
_Backend
.
TREE_ATTN
)
attn_metadata_builder
=
attn_metadata_builder_cls
(
attn_metadata_builder
=
attn_metadata_builder_cls
(
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
layer_names
=
proposer
.
attn_layer_names
,
layer_names
=
proposer
.
attn_layer_names
,
...
...
tests/v1/spec_decode/test_mtp.py
View file @
76879cc1
...
@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
...
@@ -10,7 +10,7 @@ from tests.v1.attention.utils import (
BatchSpec
,
BatchSpec
,
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
get_attention_backend
,
try_
get_attention_backend
,
)
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
(
from
vllm.config
import
(
...
@@ -177,7 +177,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
...
@@ -177,7 +177,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
sampling_metadata
=
mock
.
MagicMock
()
sampling_metadata
=
mock
.
MagicMock
()
# Setup attention metadata
# Setup attention metadata
attn_metadata_builder_cls
,
_
=
get_attention_backend
(
_Backend
.
FLASH_ATTN
)
attn_metadata_builder_cls
,
_
=
try_
get_attention_backend
(
_Backend
.
FLASH_ATTN
)
attn_metadata_builder
=
attn_metadata_builder_cls
(
attn_metadata_builder
=
attn_metadata_builder_cls
(
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
kv_cache_spec
=
create_standard_kv_cache_spec
(
proposer
.
vllm_config
),
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
76879cc1
...
@@ -9,7 +9,7 @@ import torch
...
@@ -9,7 +9,7 @@ import torch
from
tests.v1.attention.utils
import
(
from
tests.v1.attention.utils
import
(
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
create_vllm_config
,
create_vllm_config
,
get_attention_backend
,
try_
get_attention_backend
,
)
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
...
@@ -63,7 +63,7 @@ def forward_attention(
...
@@ -63,7 +63,7 @@ def forward_attention(
# Build common metadata.
# Build common metadata.
model_name
=
"meta-llama/Meta-Llama-3-8B"
model_name
=
"meta-llama/Meta-Llama-3-8B"
builder_cls
,
impl_cls
=
get_attention_backend
(
backend
)
builder_cls
,
impl_cls
=
try_
get_attention_backend
(
backend
)
vllm_config
=
create_vllm_config
(
model_name
=
model_name
,
max_model_len
=
max
(
seq_lens
))
vllm_config
=
create_vllm_config
(
model_name
=
model_name
,
max_model_len
=
max
(
seq_lens
))
if
spec_token_tree
is
not
None
:
if
spec_token_tree
is
not
None
:
# Create speculative config if token tree is specified.
# Create speculative config if token tree is specified.
...
...
vllm/attention/backends/registry.py
View file @
76879cc1
...
@@ -3,13 +3,16 @@
...
@@ -3,13 +3,16 @@
"""Attention backend registry"""
"""Attention backend registry"""
import
enum
import
enum
from
typing
import
Optional
from
vllm.utils
import
resolve_obj_by_qualname
class
_Backend
(
enum
.
Enum
):
class
_Backend
(
enum
.
Enum
):
FLASH_ATTN
=
enum
.
auto
()
FLASH_ATTN
=
enum
.
auto
()
TRITON_ATTN
=
enum
.
auto
()
TRITON_ATTN
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_
FLASH
=
enum
.
auto
()
ROCM_
ATTN
=
enum
.
auto
()
ROCM_AITER_MLA
=
enum
.
auto
()
ROCM_AITER_MLA
=
enum
.
auto
()
ROCM_AITER_FA
=
enum
.
auto
()
# used for ViT attn backend
ROCM_AITER_FA
=
enum
.
auto
()
# used for ViT attn backend
TORCH_SDPA
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
...
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
...
@@ -24,5 +27,83 @@ class _Backend(enum.Enum):
NO_ATTENTION
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
FLEX_ATTENTION
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
TREE_ATTN
=
enum
.
auto
()
ROCM_ATTN
=
enum
.
auto
()
ROCM_AITER_UNIFIED_ATTN
=
enum
.
auto
()
ROCM_AITER_UNIFIED_ATTN
=
enum
.
auto
()
# Maps each in-tree _Backend member to the fully qualified path of the class
# that implements it. Values are plain strings, so nothing is imported until
# a path is resolved. Entries may be replaced at runtime through
# register_attn_backend.
BACKEND_MAP = {
    _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",  # noqa: E501
    _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",  # noqa: E501
    _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",  # noqa: E501
    _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend",  # noqa: E501
    _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend",  # noqa: E501
    _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend",  # noqa: E501
    _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend",  # noqa: E501
    _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",  # noqa: E501
    _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",  # noqa: E501
    _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",  # noqa: E501
    _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",  # noqa: E501
    _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",  # noqa: E501
    _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",  # noqa: E501
    _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend",  # noqa: E501
    _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",  # noqa: E501
    _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",  # noqa: E501
    _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend",  # noqa: E501
}
def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
    """
    Decorator: register a custom attention backend into BACKEND_MAP.

    - If class_path is provided (and non-empty), use it.
    - Otherwise, auto-generate "<module>.<qualname>" from the class object.

    Validation: only checks if 'backend' is a valid _Backend enum member.
    The check runs when the decorator is created, before any class is
    decorated. Overwriting existing mappings is allowed. This enables other
    hardware platforms to plug in custom out-of-tree backends.

    Args:
        backend: The _Backend enum member whose mapping is set or replaced.
        class_path: Optional class path string to store instead of the path
            derived from the decorated class.

    Returns:
        A class decorator that stores the mapping in BACKEND_MAP and returns
        the decorated class unchanged.

    Raises:
        ValueError: If 'backend' is not an instance of _Backend.
    """
    if not isinstance(backend, _Backend):
        raise ValueError(f"{backend} is not a valid _Backend enum value.")

    def decorator(cls):
        # An explicit class_path wins; a missing or empty one falls back to
        # the import path derived from the class itself.
        path = class_path or f"{cls.__module__}.{cls.__qualname__}"
        BACKEND_MAP[backend] = path
        return cls

    return decorator
def backend_to_class_str(backend: _Backend) -> str:
    """Look up the class path string registered for a backend.

    Args:
        backend: The backend enum value

    Returns:
        The backend class string stored in BACKEND_MAP

    Raises:
        KeyError: If BACKEND_MAP has no entry for the given backend.
    """
    class_path = BACKEND_MAP[backend]
    return class_path
def backend_to_class(backend: _Backend) -> type:
    """Resolve a backend enum value to its backend class.

    Args:
        backend: The backend enum value

    Returns:
        The backend class, obtained by passing the registered class string
        to resolve_obj_by_qualname
    """
    # Fetch the registered path, then hand it to the resolver in one step.
    return resolve_obj_by_qualname(backend_to_class_str(backend))
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Map a backend name string onto the matching _Backend enum member.

    Returns:
        _Backend: enum value if backend_name is a valid in-tree type
        None: otherwise it's an invalid in-tree type or an out-of-tree
              platform is loaded.
    """
    assert backend_name is not None
    # __members__ maps member names to members, so .get() yields the member
    # for a known name and None for anything else.
    return _Backend.__members__.get(backend_name)
vllm/attention/layer.py
View file @
76879cc1
...
@@ -11,8 +11,8 @@ import torch.nn.functional as F
...
@@ -11,8 +11,8 @@ import torch.nn.functional as F
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
,
backend_name_to_enum
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer
import
(
from
vllm.distributed.kv_transfer
import
(
...
...
vllm/attention/selector.py
View file @
76879cc1
...
@@ -12,7 +12,7 @@ import torch
...
@@ -12,7 +12,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
,
backend_name_to_enum
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
resolve_obj_by_qualname
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
resolve_obj_by_qualname
...
@@ -20,19 +20,6 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
...
@@ -20,19 +20,6 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Map a backend name string onto the matching _Backend enum member.

    Returns:

    * _Backend: enum value if backend_name is a valid in-tree type
    * None: otherwise it's an invalid in-tree type or an out-of-tree platform
            is loaded.
    """
    assert backend_name is not None
    # Name-keyed lookup: a known name gives its member, anything else None.
    return _Backend.__members__.get(backend_name)
def
get_env_variable_attn_backend
()
->
Optional
[
_Backend
]:
def
get_env_variable_attn_backend
()
->
Optional
[
_Backend
]:
"""
"""
Get the backend override specified by the vLLM attention
Get the backend override specified by the vLLM attention
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
76879cc1
...
@@ -21,8 +21,8 @@ import torch
...
@@ -21,8 +21,8 @@ import torch
import
zmq
import
zmq
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.backends.registry
import
_Backend
,
backend_name_to_enum
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
CopyBlocksOp
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment