Unverified commit c8547ecd authored by Morpheus Guo, committed by GitHub

Enable Aiter Attention for VL model (#12699)


Co-authored-by: yuechguo <yuechguo@amd.com>
parent 7bc1dae0
@@ -13,15 +13,18 @@ from einops import rearrange
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_blackwell,
     is_cuda,
+    is_hip,
     is_npu,
     print_info_once,
 )

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -52,6 +55,10 @@ ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }

+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if _use_aiter:
+    from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func
+

 @dataclasses.dataclass
 class SingletonCache:
@@ -336,6 +343,49 @@ class VisionFlash3Attention(nn.Module):
         return output


+class VisionAiterAttention(nn.Module):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _use_aiter:
+            raise Exception("aiter_attn is only available for AMD")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
+        **kwargs,
+    ) -> torch.Tensor:
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+
+        return aiter_flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+        )
+
+
 class VisionAscendAttention(nn.Module):
     def __init__(
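For orientation, the new backend relies on the flash-attention "varlen" packing convention: bsz sequences of seq_len tokens each are concatenated along the token dimension, and cu_seqlens_q / cu_seqlens_k mark the segment boundaries. A minimal, self-contained reference in plain PyTorch (no aiter dependency; every name below is illustrative and not part of this patch) looks like:

    import torch
    import torch.nn.functional as F

    def varlen_attention_reference(q, k, v, cu_seqlens):
        # cu_seqlens holds cumulative lengths [0, len0, len0+len1, ...];
        # segment i spans packed tokens cu_seqlens[i]:cu_seqlens[i+1].
        out = torch.empty_like(q)
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            # (tokens, heads, dim) -> (1, heads, tokens, dim), as SDPA expects
            qi = q[start:end].transpose(0, 1).unsqueeze(0)
            ki = k[start:end].transpose(0, 1).unsqueeze(0)
            vi = v[start:end].transpose(0, 1).unsqueeze(0)
            oi = F.scaled_dot_product_attention(qi, ki, vi)
            out[start:end] = oi.squeeze(0).transpose(0, 1)
        return out

    bsz, seq_len, heads, head_dim = 2, 16, 4, 32
    q = torch.randn(bsz * seq_len, heads, head_dim)
    k = torch.randn(bsz * seq_len, heads, head_dim)
    v = torch.randn(bsz * seq_len, heads, head_dim)
    # Equal-length packing; presumably what _get_cu_seqlens_for_shape produces here.
    cu_seqlens = torch.arange(0, bsz + 1, dtype=torch.int32) * seq_len
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())  # feeds max_seqlen_q/k
    out = varlen_attention_reference(q, k, v, cu_seqlens)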
@@ -393,6 +443,7 @@ QKV_BACKEND_IMPL = {
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
     "ascend_attn": VisionAscendAttention,
+    "aiter_attn": VisionAiterAttention,
 }
@@ -539,6 +590,11 @@ class VisionAttention(nn.Module):
                 backend = "fa3"
             else:
                 backend = "triton_attn"
+        elif _use_aiter:
+            if get_device_capability() < (9, 4):
+                backend = "triton_attn"
+            else:
+                backend = "aiter_attn"
         else:
             backend = "sdpa"
         if backend == "fa3" and is_blackwell():
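Note on the (9, 4) gate above (an inference from the check, not stated in the patch): under ROCm, gfx942-class accelerators such as the MI300 series report device capability (9, 4), so older AMD GPUs fall back to the Triton backend even when SGLANG_USE_AITER is set.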
...@@ -2644,7 +2644,7 @@ class ServerArgs: ...@@ -2644,7 +2644,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--mm-attention-backend", "--mm-attention-backend",
type=str, type=str,
choices=["sdpa", "fa3", "triton_attn", "ascend_attn"], choices=["sdpa", "fa3", "triton_attn", "ascend_attn", "aiter_attn"],
default=ServerArgs.mm_attention_backend, default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.", help="Set multimodal attention backend.",
) )
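With both pieces in place, the backend can be selected implicitly or explicitly. A usage sketch follows; only the environment variable and the CLI flag come from this patch, while the launch entry point and the placeholder model path are assumptions for illustration:

    # Sketch only: SGLANG_USE_AITER and --mm-attention-backend come from this
    # patch; the launch entry point and model path are assumptions.
    import os

    # Implicit selection on a ROCm build: the module-level gate
    # get_bool_env_var("SGLANG_USE_AITER") and _is_hip must see this before the
    # vision attention module is imported.
    os.environ["SGLANG_USE_AITER"] = "1"

    # Explicit selection via the new CLI choice, e.g.:
    #   python -m sglang.launch_server --model-path <vl-model> \
    #       --mm-attention-backend aiter_attn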