sglang · Commits · f235498e

Unverified commit f235498e, authored Nov 05, 2025 by YAMY, committed by GitHub on Nov 05, 2025
DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill (#11892)
Parent: 149dc9aa

Showing 3 changed files with 188 additions and 4 deletions (+188 −4):
python/sglang/srt/layers/attention/nsa/nsa_indexer.py   +84 −0
python/sglang/srt/layers/attention/nsa_backend.py       +61 −2
python/sglang/srt/models/deepseek_v2.py                 +43 −2
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -242,6 +242,30 @@ class Indexer(CustomOp):
        return query, key, weights

    def _get_k_bf16(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        # Compute only key, skip query and weights (weights is discarded if fused)
        if self.fuse_wk_and_weights_proj:
            key, _ = self.fused_wk_and_weights_proj(x)[0].split(
                [self.head_dim, self.n_heads], dim=-1
            )
        else:
            key, _ = self.wk(x)
        key = self.k_norm(key)
        k_rope, _ = torch.split(
            key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
        )
        _, k_rope = self.rotary_emb(positions, k_rope, k_rope)
        key[..., : self.rope_head_dim] = k_rope
        key = rotate_activation(key)
        return key

    def _get_topk_paged(self, forward_batch: ForwardBatch, ...
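For context, the fused branch of `_get_k_bf16` above produces the index key and the per-head weights in a single matmul and then splits the result; on this key-only path the weights half is simply discarded. Below is a minimal, standalone sketch of that split with hypothetical sizes (hidden width, `head_dim`, and `n_heads` here are illustrative, not the values sglang reads from the indexer config):

import torch

# Hypothetical sizes for illustration only.
hidden, head_dim, n_heads = 4096, 128, 64
x = torch.randn(2, hidden)                         # two tokens of hidden states
fused_w = torch.randn(hidden, head_dim + n_heads)  # stand-in for the fused projection weight

# One matmul yields both halves; the weights half is dropped on the key-only path.
key, _weights = (x @ fused_w).split([head_dim, n_heads], dim=-1)
print(key.shape)  # torch.Size([2, 128])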
@@ -375,6 +399,45 @@ class Indexer(CustomOp):
        topk_result[:offset] = raw_topk_result
        return topk_result

    def _forward_cuda_k_only(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
        act_quant,
        enable_dual_stream: bool,
        metadata: BaseIndexerMetadata,
        return_indices: bool = True,
    ) -> Optional[torch.Tensor]:
        # Fast path: only compute and store k cache, skip all q and weights ops
        key = self._get_k_bf16(x, positions, enable_dual_stream)
        k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
        if not forward_batch.out_cache_loc.is_contiguous():
            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )
        # MHA doesn't need topk_indices
        if not return_indices:
            return None
        # MLA: use dummy logits with topk kernel's fast path to generate indices
        # When length <= 2048, naive_topk_cuda directly generates [0, 1, ..., length-1, -1, ...]
        seq_lens_expanded = metadata.get_seqlens_expanded()
        dummy_logits = torch.zeros(
            seq_lens_expanded.shape[0],
            self.index_topk,
            dtype=torch.float32,
            device=x.device,
        )
        return metadata.topk_transform(dummy_logits, self.index_topk)

    def forward_indexer(self, q_fp8: torch.Tensor, ...
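The dummy-logits trick in `_forward_cuda_k_only` relies on a property stated in the comment above: when every logit is zero and the sequence is no longer than `index_topk`, the top-k kernel's fast path degenerates to selecting the full prefix. A pure-PyTorch sketch of the expected result for one row (the `expected_topk_indices` helper and its -1 padding are illustrative assumptions, not the kernel's API):

import torch

def expected_topk_indices(seq_len: int, index_topk: int = 2048) -> torch.Tensor:
    # [0, 1, ..., seq_len - 1] followed by -1 padding, as described for
    # naive_topk_cuda's fast path when seq_len <= index_topk.
    out = torch.full((index_topk,), -1, dtype=torch.int32)
    out[:seq_len] = torch.arange(seq_len, dtype=torch.int32)
    return out

print(expected_topk_indices(5, 8))  # tensor([ 0,  1,  2,  3,  4, -1, -1, -1], dtype=torch.int32)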
@@ -465,6 +528,7 @@ class Indexer(CustomOp):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
        return_indices: bool = True,
    ) -> Optional[torch.Tensor]:
        if is_hip():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
        ...
@@ -490,6 +554,26 @@ class Indexer(CustomOp):
        if metadata is None:
            return None

        # Determine if should skip topk based on sequence length
        should_skip = False
        if not forward_batch.forward_mode.is_decode_or_idle():
            if forward_batch.seq_lens_cpu is not None:
                max_kv_len = forward_batch.seq_lens_cpu.max().item()
                should_skip = max_kv_len <= self.index_topk

        # Optimization: fast path when skipping topk computation
        if should_skip:
            return self._forward_cuda_k_only(
                x,
                positions,
                forward_batch,
                layer_id,
                act_quant,
                enable_dual_stream,
                metadata,
                return_indices,
            )

        query, key, weights = self._get_q_k_bf16(
            q_lora, x, positions, enable_dual_stream
        )
        ...
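The `should_skip` check above captures the core observation of this change: when every request's KV length already fits inside the top-k budget, scoring and selecting candidates cannot prune anything, so the indexer only needs to write the quantized key cache. A quick illustration with hypothetical batch lengths (2048 is the DeepSeek-V3.2 indexer top-k referenced in the diff):

import torch

index_topk = 2048                              # indexer top-k budget
seq_lens_cpu = torch.tensor([512, 1800, 96])   # hypothetical prefill batch
max_kv_len = seq_lens_cpu.max().item()
should_skip = max_kv_len <= index_topk
print(max_kv_len, should_skip)                 # 1800 True -> take the k-only fast path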
python/sglang/srt/layers/attention/nsa_backend.py
@@ -47,7 +47,7 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
-    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass(frozen=True)
@@ -823,7 +823,23 @@ class NativeSparseAttnBackend(AttentionBackend):
        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}

-        # Do absorbed multi-latent attention
+        # Detect MHA mode: multi KV heads (vs MLA with single KV head)
+        is_mha_mode = (layer.tp_k_head_num == layer.tp_q_head_num) and (
+            layer.tp_k_head_num > 1
+        )
+        # Use MHA kernel if in MHA_ONE_SHOT mode
+        if is_mha_mode and k is not None and v is not None and q_rope is None:
+            return self._forward_standard_mha(
+                q=q,
+                k=k,
+                v=v,
+                layer=layer,
+                forward_batch=forward_batch,
+                metadata=metadata,
+            )
+
+        # Do absorbed multi-latent attention (MLA path)
+        assert q_rope is not None
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
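The `is_mha_mode` check above distinguishes the two layouts this backend can receive: the absorbed MLA path collapses the KV side to a single latent head, while the one-shot MHA path keeps the full head count, so equal (and greater than one) q/k head counts identify MHA. A standalone sketch of the heuristic with example head counts (the numbers are illustrative, not DeepSeek's actual tensor-parallel configuration):

def is_mha_mode(tp_q_head_num: int, tp_k_head_num: int) -> bool:
    # MHA: as many KV heads as query heads; absorbed MLA: a single latent KV head.
    return tp_k_head_num == tp_q_head_num and tp_k_head_num > 1

print(is_mha_mode(16, 16))  # True  -> dispatch to _forward_standard_mha
print(is_mha_mode(16, 1))   # False -> fall through to the absorbed MLA path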
@@ -1154,6 +1170,49 @@ class NativeSparseAttnBackend(AttentionBackend):
        )
        return o

    def _forward_standard_mha(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        metadata: NSAMetadata,
    ) -> torch.Tensor:
        """Standard MHA using FlashAttention varlen for MHA_ONE_SHOT mode."""
        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
        v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)

        # MHA_ONE_SHOT: k/v include all tokens (prefix + current)
        cu_seqlens_q = metadata.cu_seqlens_q
        cu_seqlens_k = metadata.cu_seqlens_k
        max_seqlen_k = metadata.max_seq_len_k
        causal = True

        # Verify batch sizes match (length of cu_seqlens should be batch_size + 1)
        assert len(cu_seqlens_q) == len(cu_seqlens_k), (
            f"batch_size mismatch: cu_seqlens_q has {len(cu_seqlens_q) - 1} requests, "
            f"cu_seqlens_k has {len(cu_seqlens_k) - 1} requests"
        )

        # Determine FA version: FA3 for SM90 (Hopper), FA4 for SM100+ (Blackwell and beyond)
        device_sm_major = torch.cuda.get_device_capability()[0]
        fa_version = 4 if device_sm_major >= 10 else 3

        return flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=metadata.max_seq_len_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=layer.scaling,
            causal=causal,
            ver=fa_version,
        )

    def _forward_tilelang(self, q_all: torch.Tensor, ...
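For readers unfamiliar with the varlen convention used by `flash_attn_varlen_func` in `_forward_standard_mha`: tokens from all requests are packed along one dimension, and `cu_seqlens_*` holds the cumulative offsets with a leading zero, which is why the assertion expects both tensors to have length `batch_size + 1`. A minimal sketch of how such metadata is laid out (hypothetical lengths, not sglang's metadata-building code):

import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical per-request lengths
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen = int(seq_lens.max())

print(cu_seqlens)   # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(max_seqlen)   # 5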
python/sglang/srt/models/deepseek_v2.py
@@ -398,6 +398,34 @@ def handle_attention_aiter(attn, forward_batch):

def handle_attention_nsa(attn, forward_batch):
    """
    Select MHA or MLA based on sequence length for optimal performance.
    - Decode: MLA (avoids per-token decompression)
    - Prefill <= 2048: MHA (topk ineffective, MHA has lower FLOPs)
    - Prefill > 2048: MLA (topk filtering reduces computation significantly)
    TODO: B200 (SM100) MHA path is temporarily disabled due to FA4 gpqa accuracy issues.
    """
    if forward_batch.forward_mode.is_decode_or_idle():
        return AttnForwardMethod.MLA

    if _is_extend_without_speculative(forward_batch):
        assert forward_batch.seq_lens_cpu is not None
        max_kv_len = forward_batch.seq_lens_cpu.max().item()

        # B200 (SM100) is temporarily disabled for MHA due to FA4 accuracy issues
        # Currently only H200 (SM90) with FA3 is allowed to use MHA path
        is_hopper = _device_sm == 90

        if max_kv_len <= attn.indexer.index_topk and is_hopper:
            # NSA backend uses varlen kernel which supports MHA_ONE_SHOT
            # Check if total sequence length fits in chunk capacity
            sum_seq_lens = sum(forward_batch.seq_lens_cpu)
            # Use MHA_ONE_SHOT for best performance
            if sum_seq_lens <= forward_batch.get_max_chunk_capacity():
                return AttnForwardMethod.MHA_ONE_SHOT

    return AttnForwardMethod.MLA
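A self-contained restatement of the `handle_attention_nsa` dispatch policy, useful for eyeballing which pathway a batch would take; the `index_topk` and `max_chunk_capacity` defaults are illustrative assumptions, not the values sglang derives at runtime from the model config and `forward_batch.get_max_chunk_capacity()`:

def choose_attn_method(
    is_decode: bool,
    seq_lens: list[int],
    is_hopper: bool,
    index_topk: int = 2048,           # assumed NSA indexer budget
    max_chunk_capacity: int = 16384,  # placeholder for get_max_chunk_capacity()
) -> str:
    if is_decode:
        return "MLA"
    if is_hopper and max(seq_lens) <= index_topk and sum(seq_lens) <= max_chunk_capacity:
        return "MHA_ONE_SHOT"
    return "MLA"

print(choose_attn_method(False, [512, 1800], is_hopper=True))  # MHA_ONE_SHOT
print(choose_attn_method(False, [4096], is_hopper=True))       # MLA (topk pruning pays off)
print(choose_attn_method(True, [4096], is_hopper=True))        # MLA (decode always uses MLA)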
@@ -1466,8 +1494,21 @@ class DeepseekV2AttentionMLA(nn.Module):
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+            q_lora = self.q_a_layernorm(q)
+            q = self.q_b_proj(q_lora)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+
+            # NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk
+            if self.use_nsa and _is_extend_without_speculative(forward_batch):
+                _ = self.indexer(
+                    x=hidden_states,
+                    q_lora=q_lora,
+                    positions=positions,
+                    forward_batch=forward_batch,
+                    layer_id=self.layer_id,
+                    return_indices=False,
+                )
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim