Unverified commit 07a606aa, authored by Huamin Li, committed by GitHub
Browse files

[CI Failure] Fix backend selection for encoder-only models (#28534)


Signed-off-by: Huamin Li <3ericli@gmail.com>
parent a7791eac
......@@ -142,6 +142,17 @@ class AttentionBackend(ABC):
def is_sparse(cls) -> bool:
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether this backend can serve the requested attention type.

    The base implementation accepts decoder attention only; backends that
    handle other attention types are expected to override this method.

    Args:
        attn_type: Attention type to test; it is compared against the
            ``AttentionType`` constants.

    Returns:
        The result of comparing ``attn_type`` with
        ``AttentionType.DECODER``.
    """
    # Function-scope import, kept exactly as in the original
    # (presumably to avoid an import cycle with vllm.attention -- TODO confirm).
    from vllm.attention import AttentionType

    is_decoder = attn_type == AttentionType.DECODER
    return is_decoder
@classmethod
def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
    """Report whether this backend can run on the given device capability.

    The base implementation places no restriction and accepts every
    capability value.

    Args:
        capability: Device capability to check; not inspected here.

    Returns:
        Always ``True``.
    """
    # No capability constraint at this level.
    return True
......@@ -171,6 +182,7 @@ class AttentionBackend(ABC):
has_sink: bool,
use_sparse: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
......@@ -195,6 +207,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("non-sparse not supported")
if not cls.supports_compute_capability(device_capability):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
......
......@@ -291,6 +291,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
......
......@@ -74,7 +74,11 @@ class EncoderOnlyAttention(Attention):
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
......
......@@ -76,6 +76,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -94,6 +95,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
attn_type=attn_type,
)
......@@ -106,6 +108,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
......@@ -159,6 +162,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
else:
attention_cls = current_platform.get_attn_backend_cls(
......@@ -170,6 +174,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
attn_type,
)
if not attention_cls:
raise ValueError(
......
......@@ -134,6 +134,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -298,6 +298,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
......@@ -318,6 +319,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
......@@ -339,7 +341,13 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
......@@ -356,6 +364,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
except ImportError:
invalid_reasons = ["ImportError"]
......@@ -379,6 +388,7 @@ class CudaPlatformBase(Platform):
has_sink,
use_sparse,
device_capability,
attn_type,
)
reasons_str = (
"{"
......
......@@ -222,6 +222,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
return ""
......
......@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -61,6 +61,7 @@ class TpuPlatform(Platform):
use_mla: bool,
has_sink,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
......
......@@ -51,6 +51,7 @@ class XPUPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
......
......@@ -48,6 +48,17 @@ class CPUAttentionBackend(AttentionBackend):
def get_name() -> str:
return "CPU_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether CPU attention can serve the requested attention type.

    CPU attention accepts decoder, encoder, and encoder-only attention.
    (The previous docstring omitted the encoder type even though the
    code accepts it.)

    Args:
        attn_type: Attention type to test; it is compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` if ``attn_type`` is one of ``AttentionType.DECODER``,
        ``AttentionType.ENCODER`` or ``AttentionType.ENCODER_ONLY``;
        ``False`` otherwise.
    """
    # Function-scope import, kept exactly as in the original
    # (presumably to avoid an import cycle with vllm.attention -- TODO confirm).
    from vllm.attention import AttentionType

    supported_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
    )
    return attn_type in supported_types
@staticmethod
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
    """Return the implementation class paired with this backend."""
    # The class object itself is returned, not an instance.
    return CPUAttentionBackendImpl
......
......@@ -66,6 +66,18 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlashAttention can serve the requested attention type.

    Accepts decoder, encoder, encoder-only, and encoder-decoder attention.

    Args:
        attn_type: Attention type to test; it is compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` if ``attn_type`` is one of the four types listed in the
        body; ``False`` otherwise.
    """
    # Function-scope import, kept exactly as in the original
    # (presumably to avoid an import cycle with vllm.attention -- TODO confirm).
    from vllm.attention import AttentionType

    supported_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type in supported_types
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
    """Return the implementation class paired with this backend."""
    # The class object itself is returned, not an instance.
    return FlashAttentionImpl
......
......@@ -84,6 +84,13 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLEX_ATTENTION"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether FlexAttention can serve the requested attention type.

    Accepts decoder and encoder-only attention.

    Args:
        attn_type: Attention type to test; it is compared against the
            ``AttentionType`` constants.

    Returns:
        ``True`` if ``attn_type`` is ``AttentionType.DECODER`` or
        ``AttentionType.ENCODER_ONLY``; ``False`` otherwise.
    """
    # Function-scope import, kept exactly as in the original
    # (presumably to avoid an import cycle with vllm.attention -- TODO confirm).
    from vllm.attention import AttentionType

    supported_types = (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
    return attn_type in supported_types
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
    """Return the implementation class paired with this backend."""
    # The class object itself is returned, not an instance.
    return FlexAttentionImpl
......
......@@ -40,14 +40,14 @@ logger = init_logger(__name__)
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format:
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512
  `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
  The first `float32` is the scale for the first 128 `float8_e4m3` values,
  the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
  part is not quantized for accuracy.
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment