Unverified Commit 07a606aa authored by Huamin Li's avatar Huamin Li Committed by GitHub
Browse files

[CI Failure] Fix backend selection for encoder-only models (#28534)


Signed-off-by: default avatarHuamin Li <3ericli@gmail.com>
parent a7791eac
......@@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType
return attn_type == AttentionType.DECODER
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True
......@@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
......@@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
......
......@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
......
......@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
......
......@@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
attn_type=attn_type,
)
......@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
......@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
......@@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
if not attention_cls:
raise ValueError(
......
......@@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
......@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
......@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
......@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons = ["ImportError"]
......@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
reasons_str = (
"{"
......
......@@ -222,6 +222,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
return ""
......
......@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
......
......@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str:
return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
)
@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl
......
......@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
......
......@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment