Unverified Commit 350ca72c authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][AITER] Fix AITER import regression for explicit backend selection (#33749)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 1fb0495a
...@@ -5,10 +5,17 @@ ...@@ -5,10 +5,17 @@
import pytest import pytest
import torch import torch
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
# Import AITER backend if on ROCm and aiter is available
if current_platform.is_rocm():
from vllm._aiter_ops import is_aiter_found_and_supported
if is_aiter_found_and_supported():
import aiter
from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache
NUM_HEADS = [(4, 4), (8, 2)] NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
...@@ -102,8 +109,11 @@ def test_varlen_with_paged_kv( ...@@ -102,8 +109,11 @@ def test_varlen_with_paged_kv(
num_blocks: int, num_blocks: int,
q_dtype: torch.dtype | None, q_dtype: torch.dtype | None,
) -> None: ) -> None:
if not is_flash_attn_varlen_func_available(): from vllm._aiter_ops import is_aiter_found_and_supported
pytest.skip("flash_attn_varlen_func required to run this test.")
if not is_aiter_found_and_supported():
pytest.skip("aiter package required for this test.")
torch.set_default_device("cuda") torch.set_default_device("cuda")
set_random_seed(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
...@@ -129,6 +139,8 @@ def test_varlen_with_paged_kv( ...@@ -129,6 +139,8 @@ def test_varlen_with_paged_kv(
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum( cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32 dim=0, dtype=torch.int32
) )
# Save kv_lens as list before converting to tensor
kv_lens_list = kv_lens
kv_lens = torch.tensor(kv_lens, dtype=torch.int32) kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
...@@ -141,33 +153,83 @@ def test_varlen_with_paged_kv( ...@@ -141,33 +153,83 @@ def test_varlen_with_paged_kv(
maybe_quantized_query = query maybe_quantized_query = query
maybe_quantized_key_cache = key_cache maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache maybe_quantized_value_cache = value_cache
k_descale = None k_scale_tensor = None
v_descale = None v_scale_tensor = None
dequant = False
if q_dtype is not None: if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype) maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype) maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype) maybe_quantized_value_cache = value_cache.to(q_dtype)
dequant = True
scale_shape = (num_seqs, num_kv_heads) scale_shape = (num_seqs, num_kv_heads)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
torch.ops.vllm.flash_attn_varlen_func( # For per-seq-per-head scales (matching AITER backend expectation)
maybe_quantized_query, k_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
maybe_quantized_key_cache, v_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
maybe_quantized_value_cache,
out=output, # Prepare metadata for cp_mha_gather_cache
# token_to_batch: maps each token to its batch index
token_to_batch = torch.zeros(sum(kv_lens_list), dtype=torch.int32)
seq_starts = torch.zeros(num_seqs, dtype=torch.int32)
token_idx = 0
for batch_idx, kv_len in enumerate(kv_lens_list):
token_to_batch[token_idx : token_idx + kv_len] = batch_idx
seq_starts[batch_idx] = 0 # Assuming all sequences start at 0 in their blocks
token_idx += kv_len
# Allocate buffers for gathered KV
total_kv_tokens = sum(kv_lens_list)
gathered_key = torch.empty(
total_kv_tokens, num_kv_heads, head_size, dtype=maybe_quantized_key_cache.dtype
)
gathered_value = torch.empty(
total_kv_tokens,
num_kv_heads,
head_size,
dtype=maybe_quantized_value_cache.dtype,
)
# Gather paged KV cache into contiguous tensors using triton kernel
cp_mha_gather_cache(
key_cache=maybe_quantized_key_cache,
value_cache=maybe_quantized_value_cache,
key=gathered_key,
value=gathered_value,
block_tables=block_tables,
k_scales=k_scale_tensor
if k_scale_tensor is not None
else torch.ones(1, dtype=torch.float32),
v_scales=v_scale_tensor
if v_scale_tensor is not None
else torch.ones(1, dtype=torch.float32),
cu_seqlens_kv=cu_seq_lens,
token_to_batch=token_to_batch,
seq_starts=seq_starts,
dequant=dequant,
kv_cache_layout="NHD",
total_tokens=total_kv_tokens,
)
# Call aiter flash attention with gathered KV
aiter.flash_attn_varlen_func(
q=maybe_quantized_query,
k=gathered_key,
v=gathered_value,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_seq_lens,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len, max_seqlen_k=max_kv_len,
min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=scale, softmax_scale=scale,
alibi_slopes=None, causal=True,
window_size=window_size, window_size=window_size,
block_table=block_tables, alibi_slopes=None,
cu_seqlens_k=cu_seq_lens, return_lse=False,
k_scale=k_descale, out=output,
v_scale=v_descale,
) )
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
...@@ -175,7 +237,7 @@ def test_varlen_with_paged_kv( ...@@ -175,7 +237,7 @@ def test_varlen_with_paged_kv(
key_cache=key_cache, key_cache=key_cache,
value_cache=value_cache, value_cache=value_cache,
query_lens=query_lens, query_lens=query_lens,
kv_lens=kv_lens, kv_lens=kv_lens_list,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
sliding_window=sliding_window, sliding_window=sliding_window,
...@@ -189,3 +251,8 @@ def test_varlen_with_paged_kv( ...@@ -189,3 +251,8 @@ def test_varlen_with_paged_kv(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}", f"{torch.max(torch.abs(output - ref_output))}",
) )
# Log diff stats for tracking changes
print(f"Max abs diff: {torch.max(torch.abs(output - ref_output))}")
print(f"Mean diff: {torch.mean(torch.abs(output - ref_output))}")
print(f"Min diff: {torch.std(torch.abs(output - ref_output))}")
...@@ -14,7 +14,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( ...@@ -14,7 +14,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer_fake, rocm_aiter_sparse_attn_indexer_fake,
) )
_FP8_DTYPE = current_platform.fp8_dtype() # fp8_dtype is not cached.
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool: def is_aiter_found() -> bool:
...@@ -31,12 +34,22 @@ IS_AITER_FOUND = is_aiter_found() ...@@ -31,12 +34,22 @@ IS_AITER_FOUND = is_aiter_found()
def is_aiter_found_and_supported() -> bool: def is_aiter_found_and_supported() -> bool:
"""Check if AITER is available AND enabled via environment variable. """Check if AITER library is available and platform supports it.
Checks: platform (ROCm), device arch (gfx9), library existence, Checks: platform (ROCm), device arch (gfx9), and library existence.
and VLLM_ROCM_USE_AITER env variable. Does NOT check environment variables - that's handled by rocm_aiter_ops.is_enabled().
This function determines if aiter CAN be used, not if it SHOULD be used.
Separation of concerns:
- This function: Can aiter work on this system? (platform + library availability)
- rocm_aiter_ops.is_enabled(): Should aiter be used by default? (adds env var check)
- Backend selection: Can explicitly request aiter regardless of env var
This allows explicit backend selection via attention_config to work even when
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
""" """
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER: if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9 from vllm.platforms.rocm import on_gfx9
return on_gfx9() return on_gfx9()
...@@ -58,21 +71,6 @@ def if_aiter_supported(func: Callable) -> Callable: ...@@ -58,21 +71,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper return wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
else:
# Placeholder when AITER is disabled - prevents NameError during module load.
# Note: When AITER is disabled, ops are not registered, so fake implementations
# referencing this variable won't actually be called at runtime.
AITER_FP8_DTYPE = _FP8_DTYPE
def _rocm_aiter_fused_moe_impl( def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -539,7 +537,7 @@ def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl( ...@@ -539,7 +537,7 @@ def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE] assert quant_dtype in [torch.int8, FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
...@@ -581,7 +579,7 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl( ...@@ -581,7 +579,7 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE] assert quant_dtype in [torch.int8, FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device) y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
...@@ -630,10 +628,10 @@ def _rocm_aiter_per_token_quant_impl( ...@@ -630,10 +628,10 @@ def _rocm_aiter_per_token_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE] assert quant_dtype in [torch.int8, FP8_DTYPE]
out_shape = x.shape out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device) out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device)
if scale is None: if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant( dynamic_per_token_scaled_quant(
...@@ -653,7 +651,7 @@ def _rocm_aiter_per_token_quant_fake( ...@@ -653,7 +651,7 @@ def _rocm_aiter_per_token_quant_fake(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape out_shape = x.shape
return ( return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device), torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
) )
...@@ -675,7 +673,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( ...@@ -675,7 +673,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
None, None,
None, None,
group_size=group_size, group_size=group_size,
dtype_quant=AITER_FP8_DTYPE, dtype_quant=FP8_DTYPE,
res1=residual, res1=residual,
) )
return ( return (
...@@ -695,7 +693,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( ...@@ -695,7 +693,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
M, N = x.shape M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size) scale_shape = (M, (N + group_size - 1) // group_size)
return ( return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
torch.empty_like(residual, device=residual.device), torch.empty_like(residual, device=residual.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device),
) )
...@@ -717,7 +715,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl( ...@@ -717,7 +715,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
None, None,
None, None,
group_size=group_size, group_size=group_size,
dtype_quant=AITER_FP8_DTYPE, dtype_quant=FP8_DTYPE,
res1=None, res1=None,
) )
return (x_quant, x_quant_scales) return (x_quant, x_quant_scales)
...@@ -732,7 +730,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake( ...@@ -732,7 +730,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
M, N = x.shape M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size) scale_shape = (M, (N + group_size - 1) // group_size)
return ( return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device), torch.empty(scale_shape, dtype=torch.float32, device=x.device),
) )
...@@ -745,7 +743,7 @@ def _rocm_aiter_group_fp8_quant_impl( ...@@ -745,7 +743,7 @@ def _rocm_aiter_group_fp8_quant_impl(
from aiter import QuantType, get_hip_quant from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) return aiter_per1x128_quant(x.contiguous(), quant_dtype=FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake( def _rocm_aiter_group_fp8_quant_fake(
...@@ -753,7 +751,7 @@ def _rocm_aiter_group_fp8_quant_fake( ...@@ -753,7 +751,7 @@ def _rocm_aiter_group_fp8_quant_fake(
group_size: int, group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) x_fp8 = torch.empty((M, N), dtype=FP8_DTYPE, device=x.device)
out_bs = torch.empty( out_bs = torch.empty(
( (
M, M,
...@@ -775,7 +773,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_impl( ...@@ -775,7 +773,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x, x,
activation="silu", activation="silu",
group_size=group_size, group_size=group_size,
dtype_quant=AITER_FP8_DTYPE, dtype_quant=FP8_DTYPE,
) )
...@@ -786,7 +784,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( ...@@ -786,7 +784,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
M, N = x.shape M, N = x.shape
assert N % 2 == 0 assert N % 2 == 0
N_half = N // 2 N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) x_fp8 = torch.empty((M, N_half), dtype=FP8_DTYPE, device=x.device)
out_bs = torch.empty( out_bs = torch.empty(
( (
M, M,
...@@ -986,7 +984,7 @@ class rocm_aiter_ops: ...@@ -986,7 +984,7 @@ class rocm_aiter_ops:
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
def is_shuffle_kv_cache_enabled(cls) -> bool: def is_shuffle_kv_cache_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED return cls._SHUFFLE_KV_CACHE_ENABLED
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
...@@ -1654,5 +1652,87 @@ class rocm_aiter_ops: ...@@ -1654,5 +1652,87 @@ class rocm_aiter_ops:
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
@staticmethod
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
min_seqlen_q: int | None = None,
dropout_p: float = 0.0,
softmax_scale: float | None = None,
causal: bool = False,
window_size: tuple[int, int] | None = None,
alibi_slopes: torch.Tensor | None = None,
return_lse: bool = False,
out: torch.Tensor | None = None,
):
"""
Flash attention with variable length sequences.
This function is NOT wrapped with @is_aiter_supported decorator
to allow explicit backend selection via attention_config to work
even when VLLM_ROCM_USE_AITER=0.
Note: This performs lazy import of aiter.flash_attn_varlen_func
"""
from aiter import flash_attn_varlen_func
return flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
min_seqlen_q=min_seqlen_q,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_lse=return_lse,
out=out,
)
@staticmethod
def pa_fwd_asm(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_tables_stride0: int,
K_QScale: torch.Tensor,
V_QScale: torch.Tensor,
out_: torch.Tensor,
):
"""
Paged attention forward pass using assembly kernel.
This function is NOT wrapped with @is_aiter_supported decorator
to allow explicit backend selection via attention_config to work
even when VLLM_ROCM_USE_AITER=0.
Note: This performs lazy import of aiter.pa_fwd_asm
"""
from aiter import pa_fwd_asm
return pa_fwd_asm(
Q=Q,
K=K,
V=V,
block_tables=block_tables,
context_lens=context_lens,
block_tables_stride0=block_tables_stride0,
K_QScale=K_QScale,
V_QScale=V_QScale,
out_=out_,
)
rocm_aiter_ops.register_ops_once() rocm_aiter_ops.register_ops_once()
...@@ -8,6 +8,12 @@ from vllm.platforms import current_platform ...@@ -8,6 +8,12 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
# Track whether upstream flash-attn is available on ROCm.
# Set during module initialization and never modified afterwards.
# This module-level flag avoids repeated import attempts and ensures
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash from vllm._custom_ops import reshape_and_cache_flash
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
...@@ -26,6 +32,9 @@ elif current_platform.is_xpu(): ...@@ -26,6 +32,9 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm(): elif current_platform.is_rocm():
try: try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
# Mark that upstream flash-attn is available on ROCm
_ROCM_FLASH_ATTN_AVAILABLE = True
except ImportError: except ImportError:
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc] def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
...@@ -34,6 +43,15 @@ elif current_platform.is_rocm(): ...@@ -34,6 +43,15 @@ elif current_platform.is_rocm():
"to be installed. Please install flash-attn first." "to be installed. Please install flash-attn first."
) )
# ROCm doesn't use scheduler metadata (FA3 feature), provide stub
def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
return None
# ROCm uses the C++ custom op for reshape_and_cache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies # import here to avoid circular dependencies
...@@ -128,4 +146,30 @@ def flash_attn_supports_mla(): ...@@ -128,4 +146,30 @@ def flash_attn_supports_mla():
def is_flash_attn_varlen_func_available() -> bool: def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu() """Check if flash_attn_varlen_func is available.
This function determines whether the flash_attn_varlen_func imported at module
level is a working implementation or a stub.
Platform-specific sources:
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
- XPU: ipex_ops.flash_attn_varlen_func
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
handled separately via _aiter_ops.is_aiter_found_and_supported().
Returns:
bool: True if a working flash_attn_varlen_func implementation is available.
"""
if current_platform.is_cuda() or current_platform.is_xpu():
# CUDA and XPU always have flash_attn_varlen_func available
return True
if current_platform.is_rocm():
# Use the flag set during module import to check if
# upstream flash-attn was successfully imported
return _ROCM_FLASH_ATTN_AVAILABLE
return False
...@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 ...@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm(): if current_platform.is_rocm():
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
if rocm_aiter_ops.is_enabled():
import aiter
def block_size(x, head_dim): def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
...@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
total_tokens=swa_total_tokens, total_tokens=swa_total_tokens,
) )
aiter.flash_attn_varlen_func( rocm_aiter_ops.flash_attn_varlen_func(
q=query, q=query,
k=key_fetched, k=key_fetched,
v=value_fetched, v=value_fetched,
...@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
v_scale, v_scale,
) )
return return
out, lse = aiter.flash_attn_varlen_func( out, lse = rocm_aiter_ops.flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
total_tokens=total_token_per_batch[chunk_idx], total_tokens=total_token_per_batch[chunk_idx],
) )
suf_out, suf_lse = aiter.flash_attn_varlen_func( suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func(
q=query, q=query,
k=key_fetched, k=key_fetched,
v=value_fetched, v=value_fetched,
...@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
prefill_key = key[num_decode_tokens + num_extend_tokens :] prefill_key = key[num_decode_tokens + num_extend_tokens :]
prefill_value = value[num_decode_tokens + num_extend_tokens :] prefill_value = value[num_decode_tokens + num_extend_tokens :]
aiter.flash_attn_varlen_func( rocm_aiter_ops.flash_attn_varlen_func(
q=prefill_query, q=prefill_query,
k=prefill_key, k=prefill_key,
v=prefill_value, v=prefill_value,
...@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
) )
new_key_cache = key_cache.view_as(k_cache_template) new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template) new_value_cache = value_cache.view_as(v_cache_template)
aiter.pa_fwd_asm( rocm_aiter_ops.pa_fwd_asm(
Q=query[:num_decode_tokens], Q=query[:num_decode_tokens],
K=new_key_cache, K=new_key_cache,
V=new_value_cache, V=new_value_cache,
...@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
device=output.device, device=output.device,
) )
# import so that aiter register the op to the namespace of
# torch.ops.aiter
import aiter # noqa: F401
torch.ops.aiter.paged_attention_v1( torch.ops.aiter.paged_attention_v1(
output[:num_decode_tokens], output[:num_decode_tokens],
workspace_buffer, workspace_buffer,
......
...@@ -222,9 +222,13 @@ class SpecDecodeBaseProposer: ...@@ -222,9 +222,13 @@ class SpecDecodeBaseProposer:
RocmAttentionMetadata, RocmAttentionMetadata,
] ]
# ROCM_AITER_FA is an optional backend # ROCM_AITER_FA is an optional backend
from vllm._aiter_ops import rocm_aiter_ops # We check is_enabled() here to avoid importing the backend module during
# auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
if rocm_aiter_ops.is_enabled() and find_spec( # import and JIT compilation warnings. Explicit backend selection via
# attention_config still works because the backend module is loaded
# directly when selected, not through this auto-discovery path.
# Check if backend module exists to allow explicit selection
if find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
): ):
from vllm.v1.attention.backends.rocm_aiter_fa import ( from vllm.v1.attention.backends.rocm_aiter_fa import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment