Unverified Commit 07a606aa authored by Huamin Li's avatar Huamin Li Committed by GitHub
Browse files

[CI Failure] Fix backend selection for encoder-only models (#28534)


Signed-off-by: default avatarHuamin Li <3ericli@gmail.com>
parent a7791eac
...@@ -142,6 +142,17 @@ class AttentionBackend(ABC): ...@@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool: def is_sparse(cls) -> bool:
return False return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType
return attn_type == AttentionType.DECODER
@classmethod @classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
return True return True
...@@ -171,6 +182,7 @@ class AttentionBackend(ABC): ...@@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]: ) -> list[str]:
invalid_reasons = [] invalid_reasons = []
if not cls.supports_head_size(head_size): if not cls.supports_head_size(head_size):
...@@ -195,6 +207,8 @@ class AttentionBackend(ABC): ...@@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported") invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability): if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported") invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination( combination_reason = cls.supports_combination(
head_size, head_size,
dtype, dtype,
......
...@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size, block_size,
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
attn_type=attn_type,
) )
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend
......
...@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention): ...@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16 block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
) )
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend) attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
......
...@@ -76,6 +76,7 @@ def get_attn_backend( ...@@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
...@@ -94,6 +95,7 @@ def get_attn_backend( ...@@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse, use_sparse=use_sparse,
attn_type=attn_type,
) )
...@@ -106,6 +108,7 @@ def _cached_get_attn_backend( ...@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
# Check whether a particular choice of backend was # Check whether a particular choice of backend was
# previously forced. # previously forced.
...@@ -159,6 +162,7 @@ def _cached_get_attn_backend( ...@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
else: else:
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
...@@ -170,6 +174,7 @@ def _cached_get_attn_backend( ...@@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type,
) )
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(
......
...@@ -134,6 +134,7 @@ class CpuPlatform(Platform): ...@@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform): ...@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", list[str]],
...@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform): ...@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
...@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform): ...@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
...@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform): ...@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] invalid_reasons = ["ImportError"]
...@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform): ...@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink, has_sink,
use_sparse, use_sparse,
device_capability, device_capability,
attn_type,
) )
reasons_str = ( reasons_str = (
"{" "{"
......
...@@ -222,6 +222,7 @@ class Platform: ...@@ -222,6 +222,7 @@ class Platform:
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None,
) -> str: ) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""
......
...@@ -216,6 +216,7 @@ class RocmPlatform(Platform): ...@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla, use_mla,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -61,6 +61,7 @@ class TpuPlatform(Platform): ...@@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink, has_sink,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
......
...@@ -51,6 +51,7 @@ class XPUPlatform(Platform): ...@@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse, use_sparse,
attn_type: str | None = None,
) -> str: ) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout from vllm.v1.attention.backends.utils import set_kv_cache_layout
......
...@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend): ...@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "CPU_ATTN" return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
)
@staticmethod @staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]: def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl return CPUAttentionBackendImpl
......
...@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod @staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]: def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl return FlashAttentionImpl
......
...@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLEX_ATTENTION" return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod @staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]: def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl return FlexAttentionImpl
......
...@@ -40,14 +40,14 @@ logger = init_logger(__name__) ...@@ -40,14 +40,14 @@ logger = init_logger(__name__)
""" """
NOTE: FlashMLA Sparse uses an fp8 cache with the following format NOTE: FlashMLA Sparse uses an fp8 cache with the following format
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as: structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512 - **First 512 bytes:** The "quantized NoPE" part, containing 512
`float8_e4m3` values. `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values. - **Next 16 bytes:** Scale factors, containing 4 `float32` values.
The first `float32` is the scale for the first 128 `float8_e4m3` values, The first `float32` is the scale for the first 128 `float8_e4m3` values,
the second for the next 128, and so on. the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This - **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
part is not quantized for accuracy. part is not quantized for accuracy.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment