sglang: commit a2f7218a

support using fa4 on deepseek on blackwell (#9928)

Unverified commit a2f7218a, authored Sep 17, 2025 by cicirori; committed via GitHub Sep 16, 2025.
Parent commit: 311de47b

Showing 7 changed files with 136 additions and 0 deletions (+136, -0).
Changed files:

  python/sglang/srt/entrypoints/engine.py                          +7    -0
  python/sglang/srt/layers/attention/flashattention_backend.py    +12    -0
  python/sglang/srt/layers/attention/hybrid_attn_backend.py        +1    -0
  python/sglang/srt/model_executor/model_runner.py                +10    -0
  python/sglang/srt/models/deepseek_v2.py                           +3    -0
  python/sglang/srt/server_args.py                                  +1    -0
  sgl-kernel/python/sgl_kernel/_fa4_interface.py                  +102    -0
python/sglang/srt/entrypoints/engine.py

@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
         os.environ["TRTLLM_ENABLE_PDL"] = "1"
+    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
+        # Default to warning level, to avoid too many logs
+        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
+        # Need to set log to console, otherwise the log level won't take effect
+        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
         f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
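For context on the defaults introduced above: the string "30" matches the numeric value Python's logging module assigns to WARNING, and the Cute-DSL logger presumably interprets CUTE_DSL_LOG_LEVEL the same way (an assumption inferred from the in-diff comment). A minimal check, not part of the commit:

    # Minimal sketch: "30" is the standard-library numeric level for WARNING.
    import logging

    assert logging.WARNING == 30
    assert logging.getLevelName(30) == "WARNING"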
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
         speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
+        fa_impl_ver=3,
     ):
         super().__init__()

@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
         )
         self.speculative_step_id = speculative_step_id
+        self.fa_impl_ver = fa_impl_ver
+
         # Local attention settings
         self.attention_chunk_size = (
             model_runner.attention_chunk_size

@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks

@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
         # Use Flash Attention for prefill
         if not self.use_mla:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id

@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=False,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache

@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
                 )
                 if forward_batch.mha_return_lse:
                     output, lse, *rest = output

@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
                     return output, lse
             return output
         else:
+            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id

@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:

@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
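The `kwargs` handling above is the whole compatibility story: FA3 callers keep passing exactly the arguments they always did, and `ver` only appears when a non-default implementation (here FA4) is selected. A self-contained sketch of that pattern, using a hypothetical stand-in for the real flash_attn_varlen_func:

    # attn_stub is a hypothetical stand-in; it only illustrates the
    # "add new kwargs only when they differ from the FA3 default" pattern.
    def attn_stub(q, k, v, ver=3, sinks=None):
        return {"ver": ver, "has_sinks": sinks is not None}

    def build_fa_kwargs(fa_impl_ver, sinks=None):
        kwargs = {}
        if fa_impl_ver != 3:          # FA3 callers never see the new keyword
            kwargs["ver"] = fa_impl_ver
        if sinks is not None:
            kwargs["sinks"] = sinks
        return kwargs

    print(attn_stub("q", "k", "v", **build_fa_kwargs(3)))  # {'ver': 3, 'has_sinks': False}
    print(attn_stub("q", "k", "v", **build_fa_kwargs(4)))  # {'ver': 4, 'has_sinks': False}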
python/sglang/srt/layers/attention/hybrid_attn_backend.py

@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
         self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
+        self.data_type = model_runner.kv_cache_dtype

     def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
         """
python/sglang/srt/model_executor/model_runner.py

@@ -516,6 +516,7 @@ class ModelRunner:
                 "aiter",
                 "flashinfer",
                 "fa3",
+                "fa4",
                 "triton",
                 "flashmla",
                 "cutlass_mla",

@@ -1800,6 +1801,15 @@ class ModelRunner:
             )
             return FlashAttentionBackend(self)
+        elif backend_str == "fa4":
+            assert (
+                self.use_mla_backend
+            ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+            return FlashAttentionBackend(self, fa_impl_ver=4)
         elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
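A hedged usage sketch of the new dispatch, assuming sglang's offline Engine forwards keyword arguments to ServerArgs (the model id below is a placeholder):

    import sglang as sgl

    # Assumption: Engine(**kwargs) populates ServerArgs, so attention_backend="fa4"
    # reaches the dispatch above and returns FlashAttentionBackend(self, fa_impl_ver=4).
    engine = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # placeholder model id
        attention_backend="fa4",
        trust_remote_code=True,
    )
    print(engine.generate("Hello", {"max_new_tokens": 8}))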
python/sglang/srt/models/deepseek_v2.py

@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "fa4":
+            # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+            return AttnForwardMethod.MHA_CHUNKED_KV
         elif attention_backend == "trtllm_mla":
             original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
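To make the routing explicit: with the fa4 backend, DeepSeek attention is pinned to the chunked-KV MHA path instead of the usual MLA sub-dispatch. A toy, illustrative-only sketch (enum member names mirror the diff; `_dispatch_mla_subtype` is stubbed):

    from enum import Enum, auto

    class AttnForwardMethod(Enum):
        MHA_CHUNKED_KV = auto()
        MLA = auto()

    def _dispatch_mla_subtype():
        return AttnForwardMethod.MLA  # stub for the real sub-dispatch

    def choose_forward_method(attention_backend: str) -> AttnForwardMethod:
        if attention_backend == "fa4":
            # Mirrors the TODO in the diff: use FA4 MHA for DeepSeekV3 for now.
            return AttnForwardMethod.MHA_CHUNKED_KV
        return _dispatch_mla_subtype()

    print(choose_forward_method("fa4"))  # AttnForwardMethod.MHA_CHUNKED_KV
    print(choose_forward_method("fa3"))  # AttnForwardMethod.MLA (stubbed)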
python/sglang/srt/server_args.py

@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
     # NVIDIA specific
     "cutlass_mla",
     "fa3",
+    "fa4",
     "flashinfer",
     "flashmla",
     "trtllm_mla",
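Assuming ATTENTION_BACKEND_CHOICES feeds the `choices` list of the existing --attention-backend flag (an assumption about how server_args wires its parser), this one-line addition is what makes "fa4" a legal command-line value. A minimal argparse sketch:

    import argparse

    # Trimmed copy of the list above, for illustration only.
    ATTENTION_BACKEND_CHOICES = ["cutlass_mla", "fa3", "fa4", "flashinfer", "flashmla", "trtllm_mla"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--attention-backend", choices=ATTENTION_BACKEND_CHOICES, default=None)
    args = parser.parse_args(["--attention-backend", "fa4"])
    print(args.attention_backend)  # fa4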
sgl-kernel/python/sgl_kernel/_fa4_interface.py

@@ -4,9 +4,15 @@
 # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.

+import copy
+import gc
+import logging
 import math
 from typing import Optional, Tuple

+logger = logging.getLogger(__name__)
+
 import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute

@@ -20,6 +26,22 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x


+def _reason_recompile(compile_key, jit_func):
+    compile_cache = jit_func.compile_cache
+    compile_key_map = jit_func.compile_key_map
+    if not compile_cache:
+        return "not compiled yet"
+    for k, v in compile_cache.items():
+        if k == compile_key:
+            continue
+        if len(k) != len(compile_key):
+            continue
+        for i in range(len(k)):
+            if k[i] != compile_key[i]:
+                return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]}"
+    return "unknown reason"
+
+
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
     torch.bfloat16: cutlass.BFloat16,

@@ -254,6 +276,9 @@ def _flash_attn_fwd(
         compute_capability,
     )
     if compile_key not in _flash_attn_fwd.compile_cache:
+        logger.info(
+            f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
+        )
         if compute_capability == 9:
             assert page_table is None, "paged KV not supported on SM 9.0"
             # fa_fwd = FlashAttentionForwardSm80(
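An illustrative call of the new `_reason_recompile` helper with fake stand-in objects, assuming the sgl-kernel package (and its Cute-DSL dependencies) is installed so the private helper is importable:

    # Illustrative only: shows the kind of message logged before a recompile.
    from sgl_kernel._fa4_interface import _reason_recompile

    class FakeJit:
        # Pretend one kernel variant was already compiled.
        compile_cache = {("bf16", 128, True): object()}
        compile_key_map = ["dtype", "head_dim", "causal"]

    print(_reason_recompile(("bf16", 128, False), FakeJit))
    # -> "diff at 'causal': True vs False"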
@@ -335,8 +360,85 @@ def _flash_attn_fwd(

 _flash_attn_fwd.compile_cache = {}
+_flash_attn_fwd.compile_key_map = [
+    "dtype",
+    "head_dim",
+    "head_dim_v",
+    "qhead_per_kvhead",
+    "causal",
+    "softcap is not None",
+    "lse is None",
+    "cu_seqlens_q is None",
+    "cu_seqlens_k is None",
+    "seqused_q is None",
+    "seqused_k is None",
+    "page_table is not None",
+    "window_size_left is not None",
+    "window_size_right is not None",
+    "learnable_sink is not None",
+    "m_block_size",
+    "n_block_size",
+    "num_threads",
+    "pack_gqa",
+    "compute_capability",
+]
+
+
+def warmup_flash_attn(f):
+    """
+    Decorator for flash_attn_varlen_func:
+    - On the first call, run several warmup passes with different flag combinations
+    - Warmups are executed sequentially to minimize peak GPU memory usage
+    - Does not modify user-provided tensors (clones data)
+    - Easy to extend with more compile-key dimensions
+    """
+    done = False
+
+    def _clone_args(args, kwargs):
+        """Clone tensor arguments to avoid sharing storage; deepcopy for others."""
+
+        def maybe_clone(x):
+            if isinstance(x, torch.Tensor):
+                return x.clone()
+            return copy.deepcopy(x)
+
+        return tuple(maybe_clone(a) for a in args), {
+            k: maybe_clone(v) for k, v in kwargs.items()
+        }
+
+    def _run_warmups(args, kwargs):
+        """Run warmup calls sequentially and release memory after each."""
+        base_args, base_kwargs = _clone_args(args, kwargs)
+        # Warmup combinations for return_softmax_lse and causal
+        combos = [
+            dict(return_softmax_lse=False, causal=False),
+            dict(return_softmax_lse=False, causal=True),
+            dict(return_softmax_lse=True, causal=False),
+            dict(return_softmax_lse=True, causal=True),
+        ]
+        for combo in combos:
+            wa, wk = _clone_args(base_args, base_kwargs)
+            wk.update(combo)
+            with torch.cuda.stream(torch.cuda.current_stream()):
+                f(*wa, **wk)
+            del wa, wk
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def wrapper(*args, **kwargs):
+        nonlocal done
+        if not done:
+            logger.info("Running flash_attn_varlen_func warmup passes...")
+            _run_warmups(args, kwargs)
+            done = True
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
+@warmup_flash_attn
 def flash_attn_varlen_func(
     q: torch.Tensor,
     k: torch.Tensor,
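The decorator above exists so the (return_softmax_lse, causal) variants get JIT-compiled before the first real request hits them. A toy, CPU-only sketch of the same first-call warmup pattern, with a stand-in function rather than the real flash_attn_varlen_func:

    import itertools

    def warmup_once(f):
        """Run the 4 (return_softmax_lse, causal) combinations before the first real call."""
        done = False

        def wrapper(*args, **kwargs):
            nonlocal done
            if not done:
                for lse, causal in itertools.product([False, True], repeat=2):
                    f(*args, **{**kwargs, "return_softmax_lse": lse, "causal": causal})
                done = True
            return f(*args, **kwargs)

        return wrapper

    @warmup_once
    def fake_attn(x, return_softmax_lse=False, causal=False):
        print(f"run: lse={return_softmax_lse} causal={causal}")
        return x

    fake_attn(1.0)  # prints 4 warmup lines, then runs the real call (5 lines total)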