Commit 07a606aa (unverified), authored by Huamin Li, committed by GitHub
Browse files

[CI Failure] Fix backend selection for encoder-only models (#28534)


Signed-off-by: Huamin Li <3ericli@gmail.com>
parent a7791eac
...@@ -142,6 +142,17 @@ class AttentionBackend(ABC): ...@@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool: def is_sparse(cls) -> bool:
return False return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether this backend can handle ``attn_type``.

    The base implementation accepts only ``AttentionType.DECODER``.
    Backends that handle other attention types should override this
    method.

    Args:
        attn_type: Attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` when ``attn_type`` equals ``AttentionType.DECODER``,
        ``False`` otherwise.
    """
    # Function-scope import — presumably avoids an import cycle with
    # vllm.attention; TODO confirm.
    from vllm.attention import AttentionType

    is_decoder = attn_type == AttentionType.DECODER
    return is_decoder
@classmethod @classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True return True
...@@ -171,6 +182,7 @@ class AttentionBackend(ABC): ...@@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]: ) -> list[str]:
invalid_reasons = [] invalid_reasons = []
if not cls.supports_head_size(head_size): if not cls.supports_head_size(head_size):
...@@ -195,6 +207,8 @@ class AttentionBackend(ABC): ...@@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported") invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability): if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported") invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination( combination_reason = cls.supports_combination(
head_size, head_size,
dtype, dtype,
......
...@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size, block_size,
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
attn_type=attn_type,
) )
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend
......
...@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention): ...@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16 block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
) )
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
......
...@@ -76,6 +76,7 @@ def get_attn_backend( ...@@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
...@@ -94,6 +95,7 @@ def get_attn_backend( ...@@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
attn_type=attn_type,
) )
...@@ -106,6 +108,7 @@ def _cached_get_attn_backend( ...@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
# Check whether a particular choice of backend was # Check whether a particular choice of backend was
# previously forced. # previously forced.
...@@ -159,6 +162,7 @@ def _cached_get_attn_backend( ...@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
else: else:
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
...@@ -170,6 +174,7 @@ def _cached_get_attn_backend( ...@@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(
......
...@@ -134,6 +134,7 @@ class CpuPlatform(Platform): ...@@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform): ...@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", list[str]],
...@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform): ...@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
...@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform): ...@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
...@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform): ...@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] invalid_reasons = ["ImportError"]
...@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform): ...@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
reasons_str = ( reasons_str = (
"{" "{"
......
...@@ -222,6 +222,7 @@ class Platform: ...@@ -222,6 +222,7 @@ class Platform:
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""
......
...@@ -216,6 +216,7 @@ class RocmPlatform(Platform): ...@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -61,6 +61,7 @@ class TpuPlatform(Platform): ...@@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -51,6 +51,7 @@ class XPUPlatform(Platform): ...@@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout
......
...@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend): ...@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "CPU_ATTN" return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether the CPU attention backend can handle ``attn_type``.

    Accepted types are decoder, encoder, and encoder-only attention.

    Args:
        attn_type: Attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` when ``attn_type`` is one of ``DECODER``, ``ENCODER``
        or ``ENCODER_ONLY``; ``False`` otherwise.
    """
    # Function-scope import — presumably avoids an import cycle with
    # vllm.attention; TODO confirm.
    from vllm.attention import AttentionType

    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
    )
    return attn_type in accepted_types
@staticmethod @staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]: def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl return CPUAttentionBackendImpl
......
...@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlashAttention can handle ``attn_type``.

    Accepted types are decoder, encoder, encoder-only, and
    encoder-decoder attention.

    Args:
        attn_type: Attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` when ``attn_type`` is one of ``DECODER``, ``ENCODER``,
        ``ENCODER_ONLY`` or ``ENCODER_DECODER``; ``False`` otherwise.
    """
    # Function-scope import — presumably avoids an import cycle with
    # vllm.attention; TODO confirm.
    from vllm.attention import AttentionType

    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in accepted_types
@staticmethod @staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]: def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl return FlashAttentionImpl
......
...@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLEX_ATTENTION" return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlexAttention can handle ``attn_type``.

    Accepted types are decoder and encoder-only attention.

    Args:
        attn_type: Attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` when ``attn_type`` is ``DECODER`` or ``ENCODER_ONLY``;
        ``False`` otherwise.
    """
    # Function-scope import — presumably avoids an import cycle with
    # vllm.attention; TODO confirm.
    from vllm.attention import AttentionType

    accepted_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER_ONLY,
    )
    return attn_type in accepted_types
@staticmethod @staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]: def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl return FlexAttentionImpl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment