change / sglang · commit ef4a8097 (Unverified)

Rename flashmla kernel options of nsa backend for better readability (#11876)

Authored by Baizhou Zhang on Oct 21, 2025; committed by GitHub on Oct 21, 2025.
Parent: ebff4ee6

Showing 3 changed files with 31 additions and 31 deletions:

- docs/advanced_features/server_arguments.md (+2, -0)
- python/sglang/srt/layers/attention/nsa_backend.py (+19, -21)
- python/sglang/srt/server_args.py (+10, -10)
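For quick reference, the renames introduced by this commit can be summarized as a mapping. This is a readability aid distilled from the diff below, not code from the repository:

```python
# Old name -> new name, as introduced by this commit.
KERNEL_OPTION_RENAMES = {
    "flashmla_prefill": "flashmla_sparse",  # NSA prefill kernel option
    "flashmla_decode": "flashmla_kv",       # NSA decode kernel option
}
CLI_FLAG_RENAMES = {
    "--nsa-prefill": "--nsa-prefill-backend",
    "--nsa-decode": "--nsa-decode-backend",
}
SERVER_ARGS_FIELD_RENAMES = {
    "nsa_prefill": "nsa_prefill_backend",
    "nsa_decode": "nsa_decode_backend",
}
```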
docs/advanced_features/server_arguments.md

@@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
+| `--nsa-prefill-backend` | Prefill attention implementation for nsa backend. | `flashmla_sparse` |
+| `--nsa-decode-backend` | Decode attention implementation for nsa backend. | `flashmla_kv` |
 
 ## Speculative decoding
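As an illustration of the two documented flags, a server launch could look like the sketch below. This is hedged: the `sglang.launch_server` entry point and the model path are assumptions, not part of this diff.

```python
# Hypothetical launch using the renamed NSA backend flags.
# Entry point and model path are assumptions, not taken from this commit.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "<your-nsa-capable-model>",   # placeholder
    "--nsa-prefill-backend", "flashmla_sparse",   # documented default
    "--nsa-decode-backend", "flashmla_kv",
]
subprocess.run(cmd, check=True)
```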
python/sglang/srt/layers/attention/nsa_backend.py

@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
 )
-_NSA_IMPL_T: TypeAlias = Literal["flashmla_prefill", "flashmla_decode", "fa3", "tilelang"]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
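A minimal, self-contained sketch of what this `Literal` alias buys: a static type checker rejects the old spellings once only the renamed options are listed. The helper function below is illustrative, not sglang code:

```python
from typing import Literal, TypeAlias

_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]

def set_backend(impl: _NSA_IMPL_T) -> str:
    # Illustrative helper; in sglang the values land in NSA_PREFILL_IMPL / NSA_DECODE_IMPL.
    return impl

set_backend("flashmla_kv")        # accepted
# set_backend("flashmla_decode")  # rejected by mypy/pyright: not a member of _NSA_IMPL_T
```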
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                     cache_seqlens=nsa_cache_seqlens_int32,
                     seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                 ),
                 seq_len_q=1,
             )
-            if NSA_DECODE_IMPL == "flashmla_decode"
+            if NSA_DECODE_IMPL == "flashmla_kv"
             else None
         ),
     }
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         seqlens_expanded = cache_seqlens_int32
         nsa_extend_seq_lens_list = [1] * num_tokens
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata["flashmla_metadata"].slice(slice(0, num_tokens + 1))
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             )
             nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
-            if NSA_DECODE_IMPL == "flashmla_decode":
+            if NSA_DECODE_IMPL == "flashmla_kv":
                 flashmla_metadata = self.decode_cuda_graph_metadata["flashmla_metadata"].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             )
             nsa_extend_seq_lens_list = [1] * bs
-            if NSA_DECODE_IMPL == "flashmla_decode":
+            if NSA_DECODE_IMPL == "flashmla_kv":
                 flashmla_metadata = self.decode_cuda_graph_metadata["flashmla_metadata"].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         else:
             assert metadata.real_page_table is metadata.page_table_1
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = metadata.flashmla_metadata.slice(
                 slice(0, seqlens_expanded_size + 1)
             )
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_decode":
+        elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
             page_size=1,
         )
-        if NSA_DECODE_IMPL == "flashmla_prefill":
+        if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_DECODE_IMPL == "flashmla_decode":
+        elif NSA_DECODE_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o  # type: ignore
 
-    def _forward_flashmla_prefill(
+    def _forward_flashmla_sparse(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def _forward_flashmla_decode(
+    def _forward_flashmla_kv(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
python/sglang/srt/server_args.py

@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 DEFAULT_LORA_EVICTION_POLICY = "lru"
-NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
+NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -324,8 +324,8 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
-    nsa_prefill: str = "flashmla_prefill"
-    nsa_decode: str = "fa3"
+    nsa_prefill_backend: str = "flashmla_sparse"
+    nsa_decode_backend: str = "fa3"
 
     # Speculative decoding
     enable_beta_spec: bool = False
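Condensed to just the renamed fields, the defaults above behave like the standalone dataclass below. The class name is a stand-in; the real `ServerArgs` has many more fields:

```python
from dataclasses import dataclass

@dataclass
class NsaArgsSketch:  # stand-in for the relevant slice of ServerArgs
    nsa_prefill_backend: str = "flashmla_sparse"
    nsa_decode_backend: str = "fa3"

print(NsaArgsSketch())
# NsaArgsSketch(nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3')
```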
@@ -1024,10 +1024,10 @@ class ServerArgs:
                 logger.warning("Setting KV cache dtype to fp8.")
 
             if self.kv_cache_dtype == "fp8_e4m3":
-                self.nsa_prefill = "flashmla_decode"
-                self.nsa_decode = "flashmla_decode"
+                self.nsa_prefill_backend = "flashmla_kv"
+                self.nsa_decode_backend = "flashmla_kv"
                 logger.warning(
-                    "Setting NSA backend to flashmla_decode for FP8 KV Cache."
+                    "Setting NSA backend to flashmla_kv for FP8 KV Cache."
                 )
 
         # Logging env vars for NSA
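The FP8 override above, extracted as standalone logic. The function is illustrative; the field names and the forced `flashmla_kv` value come from the hunk:

```python
def apply_fp8_nsa_override(
    kv_cache_dtype: str, prefill_backend: str, decode_backend: str
) -> tuple[str, str]:
    # Mirrors the hunk above: with an fp8_e4m3 KV cache, both phases are
    # switched to the flashmla_kv kernel, matching the logged warning.
    if kv_cache_dtype == "fp8_e4m3":
        return "flashmla_kv", "flashmla_kv"
    return prefill_backend, decode_backend

print(apply_fp8_nsa_override("fp8_e4m3", "flashmla_sparse", "fa3"))
# ('flashmla_kv', 'flashmla_kv')
```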
@@ -2356,14 +2356,14 @@ class ServerArgs:
             help="Set multimodal attention backend.",
         )
         parser.add_argument(
-            "--nsa-prefill",
-            default=ServerArgs.nsa_prefill,
+            "--nsa-prefill-backend",
+            default=ServerArgs.nsa_prefill_backend,
             type=str,
             choices=NSA_CHOICES,
         )
         parser.add_argument(
-            "--nsa-decode",
-            default=ServerArgs.nsa_decode,
+            "--nsa-decode-backend",
+            default=ServerArgs.nsa_decode_backend,
             type=str,
             choices=NSA_CHOICES,
         )
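A self-contained sketch of how the renamed flags parse, using the same choices list and defaults as the diff (standalone argparse, not the full ServerArgs plumbing):

```python
import argparse

NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]

parser = argparse.ArgumentParser()
parser.add_argument("--nsa-prefill-backend", type=str, default="flashmla_sparse", choices=NSA_CHOICES)
parser.add_argument("--nsa-decode-backend", type=str, default="fa3", choices=NSA_CHOICES)

args = parser.parse_args(["--nsa-decode-backend", "flashmla_kv"])
print(args.nsa_prefill_backend, args.nsa_decode_backend)  # flashmla_sparse flashmla_kv
```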