sglang · Commit cba0d8c3 (unverified)

[Feature] Support deterministic inference with FA3 backend (#10651)

Authored Sep 20, 2025 by Stefan He; committed via GitHub on Sep 20, 2025.
Parent: f1d78923

Showing 2 changed files with 25 additions and 6 deletions:

  python/sglang/srt/layers/attention/flashattention_backend.py   +17  -0
  python/sglang/srt/server_args.py                                 +8  -6
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set nums splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
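The comment in this hunk ties the change to the reduction-order problem described in the linked blog post: when the kernel picks the number of KV splits heuristically, the split count (and therefore the order in which partial attention results are combined) can vary with batch composition, and floating-point addition is not associative. The standalone sketch below is illustrative only, not sglang code; the tensor size and chunk counts are arbitrary.

# Illustrative only, not sglang code: floating-point addition is not
# associative, so reducing the same values in a different number of chunks
# can produce slightly different results. Pinning num_splits to 1 fixes the
# reduction order regardless of batch size.
import torch

torch.manual_seed(0)
scores = torch.randn(4096, dtype=torch.float32)

one_split = scores.sum()
two_splits = torch.stack([c.sum() for c in scores.chunk(2)]).sum()
eight_splits = torch.stack([c.sum() for c in scores.chunk(8)]).sum()

# The three results typically differ in the last few bits.
print(one_split.item(), two_splits.item(), eight_splits.item())

Passing 0 keeps the existing behaviour, where the kernel chooses the split count itself.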
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )

@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             o, _ = merge_state_v2_wrapper(

@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result

@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                 )
             )
             o, _ = merge_state_v2_wrapper(

@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:

@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:

@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             if use_cascade_attn:

@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
             )

@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result

@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
             )
             o, _ = merge_state_v2(
                 o,
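Since every decode and extend call site above now forwards num_splits, a quick way to exercise the feature end to end is to run the same prompt twice and compare the outputs. The sketch below uses sglang's offline Engine API and assumes that the Engine constructor forwards ServerArgs fields such as attention_backend and enable_deterministic_inference as keyword arguments; the model path is a placeholder.

# Hedged end-to-end sketch: assumes sgl.Engine accepts ServerArgs fields as
# keyword arguments and that a GPU plus the FA3 kernels are available.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    attention_backend="fa3",
    enable_deterministic_inference=True,
)

prompt = "The capital of France is"
params = {"temperature": 0.0, "max_new_tokens": 32}

first = llm.generate([prompt], params)[0]["text"]
second = llm.generate([prompt], params)[0]["text"]
print(first == second)  # expected to be True under deterministic inference

llm.shutdown()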
python/sglang/srt/server_args.py

@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
 # Allow external code to add more choices

@@ -998,11 +998,13 @@ class ServerArgs:
                     "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
                 )

-            # Check some settings
-            self.disable_radix_cache = True
-            logger.warning(
-                "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
-            )
+            # Currently, only FA3 supports radix cache. Support for other backends is in progress
+            if self.attention_backend != "fa3":
+                self.disable_radix_cache = True
+                logger.warning(
+                    "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
+                )
+
             if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
                 raise ValueError(
                     f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
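Read together, the two hunks above mean that "fa3" now joins "flashinfer" as an accepted attention backend for deterministic inference, and that the radix cache is disabled only for backends other than FA3. The helper below is a simplified restatement of that gating for illustration; the function name check_deterministic_settings is hypothetical and not part of sglang.

# Simplified restatement of the server_args.py gating above; the helper name
# is hypothetical and not part of sglang.
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]


def check_deterministic_settings(attention_backend: str) -> bool:
    """Return True if the radix cache must be disabled for deterministic inference."""
    if attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
        raise ValueError(
            f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention "
            "backends are supported for deterministic inference."
        )
    # Currently, only FA3 supports the radix cache under deterministic inference.
    return attention_backend != "fa3"


assert check_deterministic_settings("fa3") is False        # radix cache stays on
assert check_deterministic_settings("flashinfer") is True  # radix cache disabled

In practice this means an FA3 deployment can keep prefix caching while still producing reproducible outputs; with other supported backends the server falls back to disabling the radix cache and logs the warning shown in the diff.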