Commit a2f7218a (Unverified)
Authored Sep 17, 2025 by cicirori, committed by GitHub Sep 16, 2025

support using fa4 on deepseek on blackwell (#9928)

Parent: 311de47b
Showing 7 changed files with 136 additions and 0 deletions (+136, -0)
python/sglang/srt/entrypoints/engine.py                        +7    -0
python/sglang/srt/layers/attention/flashattention_backend.py  +12   -0
python/sglang/srt/layers/attention/hybrid_attn_backend.py     +1    -0
python/sglang/srt/model_executor/model_runner.py              +10   -0
python/sglang/srt/models/deepseek_v2.py                        +3    -0
python/sglang/srt/server_args.py                               +1    -0
sgl-kernel/python/sgl_kernel/_fa4_interface.py                 +102  -0
python/sglang/srt/entrypoints/engine.py

@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
        os.environ["TRTLLM_ENABLE_PDL"] = "1"
    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
        # Default to warning level, to avoid too many logs
        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
        # Need to set log to console, otherwise the log level won't take effect
        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"

    # Can also be passed as argument
    os.environ["SGLANG_RUN_ID"] = (
        f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
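For reference, "30" is the numeric value of Python's logging.WARNING (and "20" would be INFO), which is what the "warning level" comment refers to. Because the engine only fills these variables when they are unset, exporting them before launch takes precedence. A minimal override sketch, assuming Cute-DSL honours the standard Python logging scale:

# Hypothetical override, not part of this commit: make Cute-DSL chattier.
import logging
import os

os.environ["CUTE_DSL_LOG_LEVEL"] = str(logging.INFO)  # 20 instead of the default 30 (WARNING)
os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
# ...then start the sglang engine/server as usual.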
python/sglang/srt/layers/attention/flashattention_backend.py
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
        speculative_step_id=0,
        topk=0,
        speculative_num_steps=0,
        fa_impl_ver=3,
    ):
        super().__init__()

@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
        )
        self.speculative_step_id = speculative_step_id

        self.fa_impl_ver = fa_impl_ver

        # Local attention settings
        self.attention_chunk_size = (
            model_runner.attention_chunk_size
@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if self.fa_impl_ver != 3:
            kwargs["ver"] = self.fa_impl_ver
        if sinks is not None:
            kwargs["sinks"] = sinks
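This pattern keeps the existing FA3 call sites untouched: extra keyword arguments are attached only when they differ from the FA3 defaults, so the downstream FlashAttention calls see their original signature unless FA4 (or attention sinks) is in play. A self-contained sketch of the same idea; the helper name is illustrative and not part of the commit:

# Illustrative only: mirrors the conditional-kwargs pattern used above.
def _build_extra_kwargs(fa_impl_ver, sinks=None):
    kwargs = {}
    if fa_impl_ver != 3:   # only non-FA3 callers pass an explicit interface version
        kwargs["ver"] = fa_impl_ver
    if sinks is not None:  # attention sinks stay optional
        kwargs["sinks"] = sinks
    return kwargs

assert _build_extra_kwargs(3) == {}
assert _build_extra_kwargs(4) == {"ver": 4}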
@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
        # Use Flash Attention for prefill
        if not self.use_mla:
            assert self.fa_impl_ver in [3], "Only FA3 support here"
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id

@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
                    softmax_scale=layer.scaling,
                    causal=False,
                    return_softmax_lse=True,
                    **kwargs,
                )
            else:
                # MHA for extend part of sequence without attending prefix kv cache

@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
                    softmax_scale=layer.scaling,
                    causal=True,
                    return_softmax_lse=forward_batch.mha_return_lse,
                    **kwargs,
                )
                if forward_batch.mha_return_lse:
                    output, lse, *rest = output
@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
                    return output, lse
                return output
        else:
            assert self.fa_impl_ver in [3], "Only FA3 support here"
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id
@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
        k_rope: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert self.fa_impl_ver in [3], "Only FA3 support decoding"
        if k is not None:
            assert v is not None
            if save_kv_cache:
@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}
        if self.fa_impl_ver != 3:
            kwargs["ver"] = self.fa_impl_ver
        if sinks is not None:
            kwargs["sinks"] = sinks
python/sglang/srt/layers/attention/hybrid_attn_backend.py

@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend
        self.data_type = model_runner.kv_cache_dtype

    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
        """
python/sglang/srt/model_executor/model_runner.py

@@ -516,6 +516,7 @@ class ModelRunner:
                "aiter",
                "flashinfer",
                "fa3",
                "fa4",
                "triton",
                "flashmla",
                "cutlass_mla",

@@ -1800,6 +1801,15 @@ class ModelRunner:
            )
            return FlashAttentionBackend(self)
        elif backend_str == "fa4":
            assert (
                self.use_mla_backend
            ), "FlashAttention v4 support is at an early stage; only MLA models are supported for now"
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            return FlashAttentionBackend(self, fa_impl_ver=4)
        elif backend_str == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
python/sglang/srt/models/deepseek_v2.py

@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "fa4":
            # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
            return AttnForwardMethod.MHA_CHUNKED_KV
        elif attention_backend == "trtllm_mla":
            original_mode = getattr(forward_batch, "_original_forward_mode", None)
            if (
python/sglang/srt/server_args.py

@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
    # NVIDIA specific
    "cutlass_mla",
    "fa3",
    "fa4",
    "flashinfer",
    "flashmla",
    "trtllm_mla",
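With "fa4" registered as an attention backend choice, the new path is selected the same way "fa3" is today. A hedged launch sketch (the model path is a placeholder; per the model_runner assertion above only MLA models are accepted, and the FA4 kernel targets Hopper/Blackwell with nvidia-cutlass-dsl installed):

# Hypothetical usage, not part of this commit.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder MLA model
    attention_backend="fa4",               # routes to FlashAttentionBackend(fa_impl_ver=4)
    trust_remote_code=True,
)
print(llm.generate("Hello", {"max_new_tokens": 8}))

# Equivalent server form (flag name mirrors the existing fa3 choice):
#   python -m sglang.launch_server --model-path <MLA model> --attention-backend fa4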
sgl-kernel/python/sgl_kernel/_fa4_interface.py
@@ -4,9 +4,15 @@
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.

import copy
import gc
import logging
import math
from typing import Optional, Tuple

logger = logging.getLogger(__name__)

import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
@@ -20,6 +26,22 @@ def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _reason_recompile(compile_key, jit_func):
    compile_cache = jit_func.compile_cache
    compile_key_map = jit_func.compile_key_map
    if not compile_cache:
        return "not compiled yet"
    for k, v in compile_cache.items():
        if k == compile_key:
            continue
        if len(k) != len(compile_key):
            continue
        for i in range(len(k)):
            if k[i] != compile_key[i]:
                return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]}"
    return "unknown reason"


torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
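_reason_recompile only explains a cache miss: it compares the new compile key, field by field, against the keys already cached, using compile_key_map for human-readable field names. A toy illustration with stand-in objects (the real jit_func is the jit-compiled _flash_attn_fwd, whose key has the twenty fields listed later in this file):

# Illustrative only: fake objects exposing the two attributes the helper reads.
class _FakeJit:
    compile_cache = {("bf16", 128, True): "compiled-kernel"}
    compile_key_map = ["dtype", "head_dim", "causal"]

class _EmptyJit:
    compile_cache = {}
    compile_key_map = []

print(_reason_recompile(("bf16", 128, False), _FakeJit))   # diff at 'causal': True vs False
print(_reason_recompile(("bf16", 128, False), _EmptyJit))  # not compiled yet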
@@ -254,6 +276,9 @@ def _flash_attn_fwd(
        compute_capability,
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
        logger.info(
            f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
        )
        if compute_capability == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            # fa_fwd = FlashAttentionForwardSm80(
@@ -335,8 +360,85 @@ def _flash_attn_fwd(
_flash_attn_fwd.compile_cache = {}
_flash_attn_fwd.compile_key_map = [
    "dtype",
    "head_dim",
    "head_dim_v",
    "qhead_per_kvhead",
    "causal",
    "softcap is not None",
    "lse is None",
    "cu_seqlens_q is None",
    "cu_seqlens_k is None",
    "seqused_q is None",
    "seqused_k is None",
    "page_table is not None",
    "window_size_left is not None",
    "window_size_right is not None",
    "learnable_sink is not None",
    "m_block_size",
    "n_block_size",
    "num_threads",
    "pack_gqa",
    "compute_capability",
]


def warmup_flash_attn(f):
    """
    Decorator for flash_attn_varlen_func:
    - On the first call, run several warmup passes with different flag combinations
    - Warmups are executed sequentially to minimize peak GPU memory usage
    - Does not modify user-provided tensors (clones data)
    - Easy to extend with more compile-key dimensions
    """
    done = False

    def _clone_args(args, kwargs):
        """Clone tensor arguments to avoid sharing storage; deepcopy for others."""

        def maybe_clone(x):
            if isinstance(x, torch.Tensor):
                return x.clone()
            return copy.deepcopy(x)

        return tuple(maybe_clone(a) for a in args), {
            k: maybe_clone(v) for k, v in kwargs.items()
        }

    def _run_warmups(args, kwargs):
        """Run warmup calls sequentially and release memory after each."""
        base_args, base_kwargs = _clone_args(args, kwargs)
        # Warmup combinations for return_softmax_lse and causal
        combos = [
            dict(return_softmax_lse=False, causal=False),
            dict(return_softmax_lse=False, causal=True),
            dict(return_softmax_lse=True, causal=False),
            dict(return_softmax_lse=True, causal=True),
        ]
        for combo in combos:
            wa, wk = _clone_args(base_args, base_kwargs)
            wk.update(combo)
            with torch.cuda.stream(torch.cuda.current_stream()):
                f(*wa, **wk)
            del wa, wk
            torch.cuda.empty_cache()
            gc.collect()

    def wrapper(*args, **kwargs):
        nonlocal done
        if not done:
            logger.info("Running flash_attn_varlen_func warmup passes...")
            _run_warmups(args, kwargs)
            done = True
        return f(*args, **kwargs)

    return wrapper


@warmup_flash_attn
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
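The decorator's behaviour is easiest to see on a stand-in function: the first call runs the four warmup combinations on cloned arguments, then every call (including the first) invokes the real function with the caller's original arguments. A toy sketch, assuming a CUDA device is available (the real decorated target is flash_attn_varlen_func above):

# Illustrative only: stand-in for the decorated FA4 entry point.
@warmup_flash_attn
def _fake_attn(x, return_softmax_lse=False, causal=False):
    print(f"run: lse={return_softmax_lse}, causal={causal}")
    return x

x = torch.zeros(8, device="cuda")
_fake_attn(x)  # four warmup runs, then the real run with the caller's flags
_fake_attn(x)  # warmup already done: a single run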