sglang / Commits / a1f011d0
"benchmarking/vscode:/vscode.git/clone" did not exist on "8c507d92c0950305d376b19137b5d8cccccea457"
Unverified commit a1f011d0, authored Aug 22, 2025 by Mick, committed by GitHub on Aug 22, 2025

minor: determine mm attn backend based on platforms (#9303)

parent 9ec314c6
Showing 1 changed file with 40 additions and 13 deletions

python/sglang/srt/layers/attention/vision.py  +40  -13
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
-            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-
-        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+        else:
+            print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
        q = q.flatten(1, 2)
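In short, the commit replaces the inline CUDA-or-SDPA check with a single _determine_attention_backend helper whose priority is: the mm_attention_backend server argument, then the qkv_backend value passed to the constructor, then a platform default (FlashAttention 3 on SM90 CUDA GPUs, Triton on other CUDA GPUs, SDPA elsewhere), with an explicit error when "fa3" ends up selected on Blackwell. The sketch below restates that selection order as a standalone function; the function name, parameter names, and the boolean/int platform inputs are illustrative stand-ins, not the sglang APIs themselves.

from typing import Optional

def pick_mm_attention_backend(
    server_arg_backend: Optional[str],  # value of the mm_attention_backend server arg, if any
    passed_backend: Optional[str],      # qkv_backend handed to VisionAttention
    is_cuda: bool,                      # stand-in for sglang.srt.utils.is_cuda()
    capability_major: int,              # stand-in for get_device_capability()[0]
    is_blackwell: bool,                 # stand-in for is_blackwell()
) -> str:
    """Same priority as the commit: server arg > constructor arg > platform default."""
    if server_arg_backend is not None:
        backend = server_arg_backend
    elif passed_backend is not None:
        backend = passed_backend
    elif is_cuda:
        # Hopper (SM90) defaults to FlashAttention 3, other CUDA GPUs to Triton.
        backend = "fa3" if capability_major == 9 else "triton_attn"
    else:
        backend = "sdpa"

    # "fa3" is rejected on Blackwell, mirroring the commit's explicit check.
    if backend == "fa3" and is_blackwell:
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend

# No overrides on a non-CUDA host -> platform default "sdpa".
assert pick_mm_attention_backend(None, None, False, 0, False) == "sdpa"
# A server-args override wins over the constructor argument and the platform.
assert pick_mm_attention_backend("triton_attn", "sdpa", True, 9, False) == "triton_attn"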