sglang · Commits · edefab0c

Commit edefab0c (unverified), authored Oct 08, 2025 by Lifu Huang, committed via GitHub on Oct 08, 2025.

[2/2] Support MHA prefill with FlashAttention 4. (#10937)

Co-authored-by: Hieu Pham <hyhieu@gmail.com>

Parent: 97cd38e5
Showing 7 changed files with 34 additions and 23 deletions (+34 -23)
python/pyproject.toml                                            +1  -1
python/pyproject_other.toml                                      +1  -1
python/sglang/srt/entrypoints/engine.py                          +1  -1
python/sglang/srt/layers/attention/attention_registry.py         +0  -3
python/sglang/srt/layers/attention/flashattention_backend.py     +0  -1
python/sglang/srt/model_executor/model_runner.py                 +3  -9
python/sglang/srt/server_args.py                                 +28 -7
python/pyproject.toml

@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",
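The pinned sgl-kernel wheel moves from 0.3.14.post1 to 0.3.15, which carries the FlashAttention 4 kernel this change depends on. To bring an existing environment up to the new pin (mirroring the reinstall hint sglang itself prints at startup):

    pip install "sgl-kernel==0.3.15" --force-reinstall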
python/pyproject_other.toml

@@ -65,7 +65,7 @@ tracing = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
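The same bump is mirrored in the srt extra, so a fresh install that requests the extra picks up the matching kernel automatically; a minimal sketch, assuming installation from PyPI:

    pip install "sglang[srt]"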
python/sglang/srt/entrypoints/engine.py

@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
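The startup version check now requires sgl-kernel 0.3.15. As the condition above shows, the check can be bypassed deliberately via the environment variable it reads, e.g. when testing a locally built kernel (using the standard sglang entry point; <model> is a placeholder):

    SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK=1 python -m sglang.launch_server --model-path <model>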
python/sglang/srt/layers/attention/attention_registry.py

@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

     return FlashAttentionBackend(runner, fa_impl_ver=4)
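Dropping the assert means the "fa4" factory no longer refuses non-MLA models, which is what opens the MHA prefill path to FlashAttention 4. For context, register_attention_backend is a decorator-based registry; a minimal sketch of the pattern (the dict name and lookup below are illustrative, not the exact sglang internals):

    # Illustrative registry-decorator pattern (not the exact sglang code).
    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("fa4")
    def create_flashattention_v4_backend(runner):
        from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

        return FlashAttentionBackend(runner, fa_impl_ver=4)

    # The runner later resolves a backend string to an instance:
    # backend = ATTENTION_BACKENDS["fa4"](runner)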
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -754,7 +754,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
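This is the counterpart change on the execution side: the multi-head attention prefill path previously hard-failed for anything but FA3 and now accepts fa_impl_ver == 4 as well. A hedged sketch of the kind of dispatch this enables (both kernel functions are hypothetical stand-ins, not real sgl-kernel symbols):

    # Illustrative only: gate the MHA prefill kernel on fa_impl_ver.
    def fa3_varlen_attention(q, k, v):
        ...  # stand-in for the FA3 kernel call

    def fa4_varlen_attention(q, k, v):
        ...  # stand-in for the FA4 kernel shipped in sgl-kernel 0.3.15

    def mha_prefill(backend, q, k, v):
        if backend.fa_impl_ver == 4:
            return fa4_varlen_attention(q, k, v)
        return fa3_varlen_attention(q, k, v)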
python/sglang/srt/model_executor/model_runner.py

@@ -1746,16 +1746,10 @@ class ModelRunner:
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
-        )
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
+        )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
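The duplicated fallback logic moves into ServerArgs.get_attention_backends() (added below in server_args.py), and the runner keeps its existing behavior: when the resolved prefill and decode strings differ, the two backends are wrapped in a HybridAttnBackend. A minimal sketch of that dispatch idea, inferred from the import in this hunk rather than the full class:

    # Illustrative hybrid dispatch between distinct prefill and decode
    # backends (not the exact HybridAttnBackend implementation).
    class HybridAttnBackendSketch:
        def __init__(self, prefill_backend, decode_backend):
            self.prefill_backend = prefill_backend
            self.decode_backend = decode_backend

        def forward(self, *args, forward_batch, **kwargs):
            # Decode steps go to one backend, prefill/extend to the other.
            backend = (
                self.decode_backend
                if forward_batch.forward_mode.is_decode()
                else self.prefill_backend
            )
            return backend.forward(*args, forward_batch=forward_batch, **kwargs)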
python/sglang/srt/server_args.py

@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3

+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.

@@ -748,20 +761,28 @@ class ServerArgs:
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )

             if is_sm100_supported():
                 if not self.enable_dp_attention:
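Taken together, GPT-OSS models now accept "fa4" as a valid backend, and the prefill and decode choices are validated independently rather than through the single attention_backend field. A hedged usage sketch, assuming the CLI flags mirror the prefill_attention_backend / decode_attention_backend fields (<model> is a placeholder):

    python -m sglang.launch_server \
        --model-path <model> \
        --prefill-attention-backend fa4 \
        --decode-attention-backend fa3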