Unverified Commit c8547ecd authored by Morpheus Guo, committed by GitHub

Enable Aiter Attention for VL model (#12699)


Co-authored-by: yuechguo <yuechguo@amd.com>
parent 7bc1dae0
@@ -13,15 +13,18 @@ from einops import rearrange
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_capability,
    is_blackwell,
    is_cuda,
    is_hip,
    is_npu,
    print_info_once,
)

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_hip = is_hip()

if _is_cuda:
    from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -52,6 +55,10 @@ ROTARY_EMBED_CLASSES = {
    "normal": apply_rotary_pos_emb,
}

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter:
    from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func


@dataclasses.dataclass
class SingletonCache:

@@ -336,6 +343,49 @@ class VisionFlash3Attention(nn.Module):
        return output

class VisionAiterAttention(nn.Module):
    def __init__(
        self,
        **kwargs,
    ):
        if not _use_aiter:
            raise Exception("aiter_attn is only available for AMD")
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
        bsz: int,
        seq_len: int,
        **kwargs,
    ) -> torch.Tensor:
        # Resolve cu_seqlens: build it from (bsz, seq_len) when absent, or fill and
        # reuse the shared SingletonCache on the first call.
        if cu_seqlens is None:
            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
        elif isinstance(cu_seqlens, SingletonCache):
            if cu_seqlens.empty():
                cu_seqlens.set_data(
                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
                )
            cu_seqlens = cu_seqlens.get_data()
        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()

        # Dispatch to AMD aiter's variable-length flash attention kernel.
        return aiter_flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
        )
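
For intuition, here is a small self-contained sketch of the cu_seqlens handling in the forward above, assuming `_get_cu_seqlens_for_shape` (not shown in this diff) returns the usual cumulative boundaries `[0, seq_len, 2*seq_len, ...]` for a uniform batch:

```python
import torch

# Hedged sketch: reproduce the cu_seqlens -> seq_lens -> max_seqlen steps for a
# uniform batch of bsz sequences with seq_len tokens each.
bsz, seq_len = 2, 4

# Assumed output of _get_cu_seqlens_for_shape(bsz, seq_len): tensor([0, 4, 8]).
cu_seqlens = torch.arange(0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32)

seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]  # per-sequence lengths: tensor([4, 4])
max_seqlen = seq_lens.max().item()           # 4, passed as max_seqlen_q / max_seqlen_k
print(cu_seqlens, seq_lens, max_seqlen)
```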
class VisionAscendAttention(nn.Module):
    def __init__(

@@ -393,6 +443,7 @@ QKV_BACKEND_IMPL = {
    "sdpa": VisionSdpaAttention,
    "fa3": VisionFlash3Attention,
    "ascend_attn": VisionAscendAttention,
    "aiter_attn": VisionAiterAttention,
}
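
A brief sketch of how this registry is presumably consulted when a backend string is resolved; `build_vision_attn_backend` is a hypothetical helper, not part of the diff, and any constructor kwargs are assumptions:

```python
# Hedged sketch: map a backend name to its implementation class via the registry
# above. Constructing VisionAiterAttention only succeeds on a ROCm build with
# SGLANG_USE_AITER=1, per its __init__ check earlier in this diff.
def build_vision_attn_backend(name: str, **kwargs):
    if name not in QKV_BACKEND_IMPL:
        raise ValueError(f"Unknown mm attention backend: {name}")
    return QKV_BACKEND_IMPL[name](**kwargs)

# Example: build_vision_attn_backend("aiter_attn") on an AMD GPU.
```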
@@ -539,6 +590,11 @@ class VisionAttention(nn.Module):
                backend = "fa3"
            else:
                backend = "triton_attn"
        elif _use_aiter:
            if get_device_capability() < (9, 4):
                backend = "triton_attn"
            else:
                backend = "aiter_attn"
        else:
            backend = "sdpa"
        if backend == "fa3" and is_blackwell():
......
@@ -2644,7 +2644,7 @@ class ServerArgs:
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
-           choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
+           choices=["sdpa", "fa3", "triton_attn", "ascend_attn", "aiter_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
......
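
As a quick sanity check outside the server, a hedged standalone mirror of the argparse wiring above (a toy parser, not the real ServerArgs plumbing; only the flag name, choices, and help string come from this diff):

```python
import argparse

# Hypothetical standalone parser mirroring the --mm-attention-backend flag above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mm-attention-backend",
    type=str,
    choices=["sdpa", "fa3", "triton_attn", "ascend_attn", "aiter_attn"],
    default=None,  # the real default, ServerArgs.mm_attention_backend, is not shown here
    help="Set multimodal attention backend.",
)

args = parser.parse_args(["--mm-attention-backend", "aiter_attn"])
print(args.mm_attention_backend)  # -> aiter_attn
```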