Unverified Commit 8c760b6a authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

[ROCm] Refactor ROCm attention backend selection logic (#35246)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent 3ee68590
...@@ -211,6 +211,6 @@ configuration. ...@@ -211,6 +211,6 @@ configuration.
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
...@@ -103,21 +103,20 @@ def test_backend_selection( ...@@ -103,21 +103,20 @@ def test_backend_selection(
if name == "TRITON_MLA" and block_size == 1: if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1 # TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError):
get_attn_backend( get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
assert f"The selected backend, {name}" in str(exc_info.value)
else: else:
# Valid backend-block_size combination # Valid backend-block_size combination
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, block_size, use_mla=use_mla
) )
expected = name expected = name
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
backend = get_attn_backend( backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla 32, torch.float16, None, block_size, use_mla=use_mla
) )
expected = "ROCM_ATTN" expected = "ROCM_ATTN"
assert backend.get_name() == expected assert backend.get_name() == expected
......
...@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool: ...@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
return False return False
def _get_backend_priorities(
use_mla: bool,
use_sparse: bool,
) -> list[AttentionBackendEnum]:
from vllm._aiter_ops import rocm_aiter_ops
if use_sparse:
return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
if use_mla:
if rocm_aiter_ops.is_mla_enabled():
return [
AttentionBackendEnum.ROCM_AITER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.ROCM_AITER_TRITON_MLA,
]
else:
return [
AttentionBackendEnum.TRITON_MLA,
]
backends = []
# Priority 1: Check for AITER Unified Attention (must check before MHA)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
# Priority 2: Check for AITER MHA (Flash Attention)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
backends.append(AttentionBackendEnum.ROCM_ATTN)
# Default: Triton Unified Attention
backends.append(AttentionBackendEnum.TRITON_ATTN)
return backends
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
device_name: str = "rocm" device_name: str = "rocm"
...@@ -349,6 +395,39 @@ class RocmPlatform(Platform): ...@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
import vllm._rocm_C # noqa: F401 import vllm._rocm_C # noqa: F401
@classmethod
def get_valid_backends(
cls,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
]:
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
attn_selector_config.use_sparse,
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
else:
valid_backends_priorities.append((backend, priority))
return valid_backends_priorities, invalid_reasons
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
...@@ -356,118 +435,71 @@ class RocmPlatform(Platform): ...@@ -356,118 +435,71 @@ class RocmPlatform(Platform):
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None, num_heads: int | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops device_capability = cls.get_device_capability()
assert device_capability is not None
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse: # First try checking just the selected backend, if there is one.
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError( raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
) )
assert block_size == 1, ( else:
"Sparse MLA backend on ROCm only supports block size 1 for now." logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
) )
logger.info_once("Using Sparse MLA backend.") reasons_str = (
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() "{"
+ ", ".join(
if attn_selector_config.use_mla: f"{backend.name}: [{', '.join(reasons)}]"
if selected_backend is None: for backend, reasons in invalid_reasons.items()
selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else AttentionBackendEnum.TRITON_MLA
) )
if selected_backend == AttentionBackendEnum.TRITON_MLA: + "}"
if block_size != 1:
logger.info_once("Using Triton MLA backend.")
return AttentionBackendEnum.TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
) )
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA: config_str = attn_selector_config.__repr__()
logger.info("Using AITER MLA backend.") logger.debug_once(
return AttentionBackendEnum.ROCM_AITER_MLA.get_path() f"Some attention backends are not valid for {cls.device_name} with "
if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA: f"{config_str}. Reasons: {reasons_str}."
logger.info("Using AITER TRITON MLA backend.")
return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend."
) )
if len(valid_backends_priorities) == 0:
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
return AttentionBackendEnum.FLEX_ATTENTION.get_path()
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
if selected_backend == AttentionBackendEnum.ROCM_ATTN:
logger.info("Using Rocm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
if on_gfx9():
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
else:
raise ValueError( raise ValueError(
f"The selected backend, {selected_backend.name}, " f"No valid attention backend found for {cls.device_name} "
"is only supported on gfx9 architectures." f"with {config_str}. Reasons: {reasons_str}."
) )
if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # We have found some valid backends. Select the one with the
logger.info("Using Aiter Unified Attention backend.") # highest priority.
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() sorted_indices = sorted(
range(len(valid_backends_priorities)),
# Handle automatic backend selection based on environment variables key=lambda i: valid_backends_priorities[i][1],
if selected_backend is None: )
# Priority 1: Check for AITER Unified Attention (must check before MHA) selected_index = sorted_indices[0]
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: selected_backend = valid_backends_priorities[selected_index][0]
logger.info("Using Aiter Unified Attention backend.") logger.info_once(
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() "Using %s attention backend out of potential backends: %s.",
selected_backend.name,
# Priority 2: Check for AITER MHA (Flash Attention) "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1) scope="local",
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
logger.info("Using Rocm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
# Priority 4: Check for AITER enabled without specific flags
# This defaults to AITER FA only if MHA is not explicitly disabled
if (
envs.VLLM_ROCM_USE_AITER
and on_gfx9()
and envs.VLLM_ROCM_USE_AITER_MHA is not False
):
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
# Default: Triton Unified Attention
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
raise RuntimeError(
f"Attention backend {selected_backend.name} is not supported on "
"ROCm. Note that V0 attention backends have been removed."
) )
return selected_backend.get_path()
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [ return [
......
...@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton( ...@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend): class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend): ...@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [576] return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only supported block_size is 1
return block_size is None or block_size == 1
@dataclass @dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata): class ROCMAiterMLASparseMetadata(AttentionMetadata):
......
...@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend): ...@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
def supports_compute_capability(cls, capability: DeviceCapability) -> bool: def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only unsupported block_size is 1
return block_size is None or block_size != 1
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True can_return_lse_for_decode: bool = True
......
...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
from vllm.platforms.rocm import on_mi3xx
# DeviceCapability is currently created using torch.cuda.get_device_capability()
# which is known to be buggy on rocm systems. on_mi3xx uses amd-smi which is
# more reliable.
return on_mi3xx()
class AiterFlashAttentionImpl(AttentionImpl): class AiterFlashAttentionImpl(AttentionImpl):
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment