change / sglang · Commits · c8547ecd

Unverified Commit c8547ecd authored Nov 06, 2025 by Morpheus Guo, committed by GitHub Nov 05, 2025

Enable Aiter Attention for VL model (#12699)

Co-authored-by: yuechguo <yuechguo@amd.com>

Parent: 7bc1dae0
Changes: 2 changed files, with 57 additions and 1 deletion (+57 -1)

python/sglang/srt/layers/attention/vision.py  (+56 -0)
python/sglang/srt/server_args.py  (+1 -1)
python/sglang/srt/layers/attention/vision.py (view file @ c8547ecd)

@@ -13,15 +13,18 @@ from einops import rearrange
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_capability,
    is_blackwell,
    is_cuda,
    is_hip,
    is_npu,
    print_info_once,
)

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_hip = is_hip()

if _is_cuda:
    from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -52,6 +55,10 @@ ROTARY_EMBED_CLASSES = {
    "normal": apply_rotary_pos_emb,
}

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func


@dataclasses.dataclass
class SingletonCache:
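The gate above only turns on when the SGLANG_USE_AITER environment variable is truthy and the build is a HIP/ROCm one. A hypothetical, self-contained equivalent of that check (the real helper is sglang.srt.utils.get_bool_env_var and is not shown in this diff; it may accept additional truthy spellings):

    import os

    # Hypothetical equivalent of the env-var gate; the real get_bool_env_var
    # lives in sglang.srt.utils and is not part of this diff.
    def bool_env(name: str, default: str = "false") -> bool:
        return os.environ.get(name, default).strip().lower() in ("1", "true")

    use_aiter = bool_env("SGLANG_USE_AITER")  # the diff additionally requires a HIP build (_is_hip)
    print("aiter path enabled:", use_aiter)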
@@ -336,6 +343,49 @@ class VisionFlash3Attention(nn.Module):
        return output


class VisionAiterAttention(nn.Module):
    def __init__(
        self,
        **kwargs,
    ):
        if not _use_aiter:
            raise Exception("aiter_attn is only available for AMD")
        super().__init__()

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
        bsz: int,
        seq_len: int,
        **kwargs,
    ) -> torch.Tensor:
        if cu_seqlens is None:
            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
        elif isinstance(cu_seqlens, SingletonCache):
            if cu_seqlens.empty():
                cu_seqlens.set_data(
                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
                )
            cu_seqlens = cu_seqlens.get_data()
        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)

        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seq_lens.max().item()

        return aiter_flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
        )


class VisionAscendAttention(nn.Module):
    def __init__(
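The forward path above accepts a precomputed cu_seqlens tensor, a SingletonCache, or None; in the latter cases it derives cumulative sequence lengths from the batch shape before calling aiter's varlen kernel. A minimal sketch of that bookkeeping, assuming _get_cu_seqlens_for_shape (not shown in this diff) produces evenly spaced offsets for a batch of fixed-length sequences; the helper name below is a hypothetical stand-in:

    import torch

    # Hypothetical stand-in for _get_cu_seqlens_for_shape (the real helper is not
    # part of this diff): for `bsz` sequences that all have length `seq_len`, the
    # cumulative offsets are simply 0, seq_len, 2*seq_len, ..., bsz*seq_len.
    def cu_seqlens_for_shape(bsz: int, seq_len: int, device: str = "cpu") -> torch.Tensor:
        return torch.arange(
            0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=device
        )

    cu_seqlens = cu_seqlens_for_shape(bsz=2, seq_len=1024)  # [0, 1024, 2048]
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]             # per-sequence lengths: [1024, 1024]
    max_seqlen = seq_lens.max().item()                      # passed as max_seqlen_q / max_seqlen_k
    print(cu_seqlens.tolist(), seq_lens.tolist(), max_seqlen)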
@@ -393,6 +443,7 @@ QKV_BACKEND_IMPL = {
    "sdpa": VisionSdpaAttention,
    "fa3": VisionFlash3Attention,
    "ascend_attn": VisionAscendAttention,
    "aiter_attn": VisionAiterAttention,
}
@@ -539,6 +590,11 @@ class VisionAttention(nn.Module):
                    backend = "fa3"
                else:
                    backend = "triton_attn"
            elif _use_aiter:
                if get_device_capability() < (9, 4):
                    backend = "triton_attn"
                else:
                    backend = "aiter_attn"
            else:
                backend = "sdpa"
        if backend == "fa3" and is_blackwell():
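The auto-selection above routes HIP devices to aiter_attn only when the reported device capability is at least (9, 4); older parts keep the Triton path. A hedged probe, assuming sglang's get_device_capability mirrors torch.cuda.get_device_capability on ROCm builds, where gfx94x-class GPUs (for example MI300) report (9, 4):

    import torch

    # Hedged probe: on a ROCm build of PyTorch, get_device_capability() reports the
    # gfx major/minor of the GPU, so (9, 4) is assumed to map to gfx94x-class parts,
    # the ones this diff routes to the aiter_attn backend.
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        backend = "aiter_attn" if cap >= (9, 4) else "triton_attn"
        print(f"device capability {cap} -> vision attention backend: {backend}")
    else:
        print("no GPU visible; the sdpa backend would be used")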
python/sglang/srt/server_args.py (view file @ c8547ecd)

@@ -2644,7 +2644,7 @@ class ServerArgs:
        parser.add_argument(
            "--mm-attention-backend",
            type=str,
-            choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
+            choices=["sdpa", "fa3", "triton_attn", "ascend_attn", "aiter_attn"],
            default=ServerArgs.mm_attention_backend,
            help="Set multimodal attention backend.",
        )
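With the new choice registered, a VL server can opt into the backend explicitly. A hedged launch sketch: the model path is only a placeholder, and the aiter kernels are assumed to require a ROCm build with SGLANG_USE_AITER set, per the gate added in vision.py:

    import os
    import subprocess

    # Example launch (model path is a placeholder; requires sglang installed on ROCm).
    env = dict(os.environ, SGLANG_USE_AITER="1")
    subprocess.run(
        [
            "python", "-m", "sglang.launch_server",
            "--model-path", "Qwen/Qwen2-VL-7B-Instruct",
            "--mm-attention-backend", "aiter_attn",
        ],
        env=env,
        check=True,
    )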