sglang · commit f235498e (unverified)
Authored Nov 05, 2025 by YAMY; committed by GitHub on Nov 05, 2025

DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill (#11892)
Parent: 149dc9aa

Showing 3 changed files with 188 additions and 4 deletions (+188 −4)
python/sglang/srt/layers/attention/nsa/nsa_indexer.py  +84 −0
python/sglang/srt/layers/attention/nsa_backend.py      +61 −2
python/sglang/srt/models/deepseek_v2.py                +43 −2
python/sglang/srt/layers/attention/nsa/nsa_indexer.py @ f235498e
```diff
@@ -242,6 +242,30 @@ class Indexer(CustomOp):
         return query, key, weights
 
+    def _get_k_bf16(
+        self,
+        x: torch.Tensor,
+        positions: torch.Tensor,
+        enable_dual_stream: bool,
+    ):
+        # Compute only key, skip query and weights (weights is discarded if fused)
+        if self.fuse_wk_and_weights_proj:
+            key, _ = self.fused_wk_and_weights_proj(x)[0].split(
+                [self.head_dim, self.n_heads], dim=-1
+            )
+        else:
+            key, _ = self.wk(x)
+        key = self.k_norm(key)
+        k_rope, _ = torch.split(
+            key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+        )
+        _, k_rope = self.rotary_emb(positions, k_rope, k_rope)
+        key[..., : self.rope_head_dim] = k_rope
+        key = rotate_activation(key)
+        return key
+
     def _get_topk_paged(
         self,
         forward_batch: ForwardBatch,
```
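As an aside on the RoPE handling in `_get_k_bf16`: only the leading `rope_head_dim` channels of the index key are rotated, and the rotated slice is written back in place before `rotate_activation`. The sketch below restates that slice-rotate-writeback pattern in isolation; it uses a plain half-split rotation and made-up sizes, so it illustrates the idea rather than reproducing sglang's actual `rotary_emb`.

```python
import torch

def apply_partial_rope(key: torch.Tensor, positions: torch.Tensor, rope_dim: int) -> torch.Tensor:
    """Rotate only the first `rope_dim` channels of `key`; leave the rest untouched.
    Illustration only: half-split rotation, not sglang's rotary_emb layout."""
    k_rope, k_rest = key[..., :rope_dim], key[..., rope_dim:]
    half = rope_dim // 2
    inv_freq = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32) / half))
    angles = positions[:, None].float() * inv_freq[None, :]          # [tokens, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = k_rope[..., :half], k_rope[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, k_rest], dim=-1)

# Hypothetical sizes: 4 tokens, head_dim=128 with a 64-channel RoPE slice.
key = torch.randn(4, 128)
positions = torch.arange(4)
print(apply_partial_rope(key, positions, rope_dim=64).shape)  # torch.Size([4, 128])
```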
```diff
@@ -375,6 +399,45 @@ class Indexer(CustomOp):
             topk_result[:offset] = raw_topk_result
         return topk_result
 
+    def _forward_cuda_k_only(
+        self,
+        x: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        act_quant,
+        enable_dual_stream: bool,
+        metadata: BaseIndexerMetadata,
+        return_indices: bool = True,
+    ) -> Optional[torch.Tensor]:
+        # Fast path: only compute and store k cache, skip all q and weights ops
+        key = self._get_k_bf16(x, positions, enable_dual_stream)
+        k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
+        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
+            layer_id=layer_id,
+            loc=forward_batch.out_cache_loc,
+            index_k=k_fp8,
+            index_k_scale=k_scale,
+        )
+        # MHA doesn't need topk_indices
+        if not return_indices:
+            return None
+        # MLA: use dummy logits with topk kernel's fast path to generate indices.
+        # When length <= 2048, naive_topk_cuda directly generates [0, 1, ..., length - 1, -1, ...]
+        seq_lens_expanded = metadata.get_seqlens_expanded()
+        dummy_logits = torch.zeros(
+            seq_lens_expanded.shape[0],
+            self.index_topk,
+            dtype=torch.float32,
+            device=x.device,
+        )
+        return metadata.topk_transform(dummy_logits, self.index_topk)
+
     def forward_indexer(
         self,
         q_fp8: torch.Tensor,
```
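To see why all-zero logits suffice here: when a request's length is at most `index_topk`, every position must be selected, so the "top-k" output is simply `[0, 1, ..., length - 1]` padded with `-1`. The toy CPU function below reproduces that output for a couple of made-up lengths; it is not the `naive_topk_cuda` kernel or `metadata.topk_transform`, only a sketch of the expected result.

```python
import torch

def dummy_topk_indices(seq_lens: torch.Tensor, index_topk: int) -> torch.Tensor:
    """Emulate the short-sequence case: identity selection padded with -1."""
    out = torch.full((seq_lens.numel(), index_topk), -1, dtype=torch.int64)
    arange = torch.arange(index_topk)
    for row, length in enumerate(seq_lens.tolist()):
        keep = min(length, index_topk)
        out[row, :keep] = arange[:keep]
    return out

print(dummy_topk_indices(torch.tensor([3, 5]), index_topk=8))
# tensor([[ 0,  1,  2, -1, -1, -1, -1, -1],
#         [ 0,  1,  2,  3,  4, -1, -1, -1]])
```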
```diff
@@ -465,6 +528,7 @@ class Indexer(CustomOp):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         layer_id: int,
+        return_indices: bool = True,
     ) -> Optional[torch.Tensor]:
         if is_hip():
             from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
```
```diff
@@ -490,6 +554,26 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
 
+        # Determine if should skip topk based on sequence length
+        should_skip = False
+        if not forward_batch.forward_mode.is_decode_or_idle():
+            if forward_batch.seq_lens_cpu is not None:
+                max_kv_len = forward_batch.seq_lens_cpu.max().item()
+                should_skip = max_kv_len <= self.index_topk
+
+        # Optimization: fast path when skipping topk computation
+        if should_skip:
+            return self._forward_cuda_k_only(
+                x,
+                positions,
+                forward_batch,
+                layer_id,
+                act_quant,
+                enable_dual_stream,
+                metadata,
+                return_indices,
+            )
+
         query, key, weights = self._get_q_k_bf16(
             q_lora, x, positions, enable_dual_stream
         )
```
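The guard above mirrors the dispatch added in `deepseek_v2.py`: the fast path only triggers for non-decode batches whose longest KV length fits inside the top-k budget. Below is a standalone restatement of the rule; the helper name is hypothetical and the default budget of 2048 is assumed from the commit description.

```python
def should_skip_topk(seq_lens_cpu, index_topk=2048, is_decode=False):
    """Hypothetical helper: skip the q/weights projections and the real top-k
    when every request's KV length fits inside the top-k budget."""
    if is_decode or seq_lens_cpu is None:
        return False
    return max(seq_lens_cpu) <= index_topk

print(should_skip_topk([512, 1700]))            # True  -> k-only fast path
print(should_skip_topk([512, 4096]))            # False -> full indexer + top-k
print(should_skip_topk([512], is_decode=True))  # False -> decode keeps the normal path
```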
python/sglang/srt/layers/attention/nsa_backend.py @ f235498e
```diff
@@ -47,7 +47,7 @@ if _is_hip:
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
 else:
-    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
 @dataclass(frozen=True)
```
```diff
@@ -823,7 +823,23 @@ class NativeSparseAttnBackend(AttentionBackend):
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
 
-        # Do absorbed multi-latent attention
+        # Detect MHA mode: multi KV heads (vs MLA with single KV head)
+        is_mha_mode = (layer.tp_k_head_num == layer.tp_q_head_num) and (
+            layer.tp_k_head_num > 1
+        )
+        # Use MHA kernel if in MHA_ONE_SHOT mode
+        if is_mha_mode and k is not None and v is not None and q_rope is None:
+            return self._forward_standard_mha(
+                q=q,
+                k=k,
+                v=v,
+                layer=layer,
+                forward_batch=forward_batch,
+                metadata=metadata,
+            )
+
+        # Do absorbed multi-latent attention (MLA path)
         assert q_rope is not None
         kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
```
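The head-count check is the only signal the backend needs to tell the two layouts apart: the absorbed MLA path presents a single KV head, while the MHA_ONE_SHOT path carries one KV head per query head. A minimal restatement with illustrative head counts (the helper and the numbers are hypothetical):

```python
def is_mha_mode(tp_q_head_num: int, tp_k_head_num: int) -> bool:
    """Hypothetical restatement of the check above: absorbed MLA exposes a single
    KV head, MHA_ONE_SHOT exposes one KV head per query head."""
    return tp_k_head_num == tp_q_head_num and tp_k_head_num > 1

print(is_mha_mode(tp_q_head_num=128, tp_k_head_num=128))  # True  -> _forward_standard_mha
print(is_mha_mode(tp_q_head_num=128, tp_k_head_num=1))    # False -> absorbed MLA kernel
```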
```diff
@@ -1154,6 +1170,49 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
+    def _forward_standard_mha(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        metadata: NSAMetadata,
+    ) -> torch.Tensor:
+        """Standard MHA using FlashAttention varlen for MHA_ONE_SHOT mode."""
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+        v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)
+
+        # MHA_ONE_SHOT: k/v include all tokens (prefix + current)
+        cu_seqlens_q = metadata.cu_seqlens_q
+        cu_seqlens_k = metadata.cu_seqlens_k
+        max_seqlen_k = metadata.max_seq_len_k
+        causal = True
+
+        # Verify batch sizes match (length of cu_seqlens should be batch_size + 1)
+        assert len(cu_seqlens_q) == len(cu_seqlens_k), (
+            f"batch_size mismatch: cu_seqlens_q has {len(cu_seqlens_q) - 1} requests, "
+            f"cu_seqlens_k has {len(cu_seqlens_k) - 1} requests"
+        )
+
+        # Determine FA version: FA3 for SM90 (Hopper), FA4 for SM100+ (Blackwell and beyond)
+        device_sm_major = torch.cuda.get_device_capability()[0]
+        fa_version = 4 if device_sm_major >= 10 else 3
+
+        return flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=metadata.max_seq_len_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=layer.scaling,
+            causal=causal,
+            ver=fa_version,
+        )
+
     def _forward_tilelang(
         self,
         q_all: torch.Tensor,
```
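For readers less familiar with the varlen interface: `cu_seqlens_q`/`cu_seqlens_k` are cumulative per-request token offsets with a leading zero, which is why the assertion above treats their length as `batch_size + 1`. A small sketch with made-up lengths showing how such offsets are typically built:

```python
import torch
import torch.nn.functional as F

# Made-up per-request token counts for a batch of 3.
seq_lens = torch.tensor([7, 3, 12], dtype=torch.int32)

# Cumulative offsets with a leading zero: len(cu_seqlens) == batch_size + 1.
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)           # tensor([ 0,  7, 10, 22], dtype=torch.int32)
print(len(cu_seqlens) - 1)  # 3 requests, matching the assertion in _forward_standard_mha
print(int(seq_lens.max()))  # 12 -> max_seqlen_k
```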
python/sglang/srt/models/deepseek_v2.py @ f235498e
```diff
@@ -398,6 +398,34 @@ def handle_attention_aiter(attn, forward_batch):
 def handle_attention_nsa(attn, forward_batch):
+    """
+    Select MHA or MLA based on sequence length for optimal performance.
+    - Decode: MLA (avoids per-token decompression)
+    - Prefill <= 2048: MHA (topk ineffective, MHA has lower FLOPs)
+    - Prefill > 2048: MLA (topk filtering reduces computation significantly)
+
+    TODO: B200 (SM100) MHA path is temporarily disabled due to FA4 gpqa accuracy issues.
+    """
+    if forward_batch.forward_mode.is_decode_or_idle():
+        return AttnForwardMethod.MLA
+
+    if _is_extend_without_speculative(forward_batch):
+        assert forward_batch.seq_lens_cpu is not None
+        max_kv_len = forward_batch.seq_lens_cpu.max().item()
+
+        # B200 (SM100) is temporarily disabled for MHA due to FA4 accuracy issues
+        # Currently only H200 (SM90) with FA3 is allowed to use MHA path
+        is_hopper = _device_sm == 90
+
+        if max_kv_len <= attn.indexer.index_topk and is_hopper:
+            # NSA backend uses varlen kernel which supports MHA_ONE_SHOT
+            # Check if total sequence length fits in chunk capacity
+            sum_seq_lens = sum(forward_batch.seq_lens_cpu)
+            # Use MHA_ONE_SHOT for best performance
+            if sum_seq_lens <= forward_batch.get_max_chunk_capacity():
+                return AttnForwardMethod.MHA_ONE_SHOT
+        return AttnForwardMethod.MLA
+
+    return AttnForwardMethod.MLA
```
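Putting the new dispatch in one place: decode always stays on MLA, short prefill on Hopper goes to MHA_ONE_SHOT when the whole batch fits in one chunk, and everything else falls back to MLA. The simplified, self-contained restatement below uses a stand-in enum and hypothetical defaults (`index_topk=2048`, `max_chunk_capacity=16384`); it is not the sglang function itself.

```python
from enum import Enum, auto

class Method(Enum):  # stand-in for sglang's AttnForwardMethod
    MLA = auto()
    MHA_ONE_SHOT = auto()

def pick_method(is_decode, seq_lens, device_sm=90, index_topk=2048, max_chunk_capacity=16384):
    """Simplified restatement of handle_attention_nsa (defaults are assumptions)."""
    if is_decode:
        return Method.MLA                        # decode: keep the absorbed MLA path
    if (
        max(seq_lens) <= index_topk              # top-k would select everything anyway
        and device_sm == 90                      # Hopper only while the FA4 issue stands
        and sum(seq_lens) <= max_chunk_capacity  # whole batch fits in one chunk
    ):
        return Method.MHA_ONE_SHOT
    return Method.MLA

print(pick_method(False, [1024, 1800]))                 # Method.MHA_ONE_SHOT (short prefill)
print(pick_method(False, [1024, 8192]))                 # Method.MLA (long prefill, top-k pays off)
print(pick_method(True, [4096]))                        # Method.MLA (decode)
print(pick_method(False, [1024, 1800], device_sm=100))  # Method.MLA (B200 disabled for now)
```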
```diff
@@ -1466,8 +1494,21 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+            q_lora = self.q_a_layernorm(q)
+            q = self.q_b_proj(q_lora)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+
+            # NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk
+            if self.use_nsa and _is_extend_without_speculative(forward_batch):
+                _ = self.indexer(
+                    x=hidden_states,
+                    q_lora=q_lora,
+                    positions=positions,
+                    forward_batch=forward_batch,
+                    layer_id=self.layer_id,
+                    return_indices=False,
+                )
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
```
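A note on the `return_indices=False` call above: on the MHA prefill path the indexer is invoked purely for its side effect of filling the quantized index-K cache, so a later decode step (which always takes the MLA/NSA path) finds the keys already cached. The mock below sketches that call pattern; the class and its fields are invented for illustration, and the real caching lives in `token_to_kv_pool`.

```python
class MockIndexer:
    """Invented mock, only to show the call pattern."""

    def __init__(self):
        self.index_k_cache = {}  # layer_id -> cached (quantized) index keys

    def __call__(self, x, layer_id, return_indices=True):
        self.index_k_cache[layer_id] = x   # side effect happens on every call
        if not return_indices:
            return None                    # MHA prefill: top-k indices are unused
        return list(range(len(x)))         # placeholder for the top-k index output

indexer = MockIndexer()
assert indexer(x=[0.1, 0.2, 0.3], layer_id=0, return_indices=False) is None
assert indexer.index_k_cache[0] == [0.1, 0.2, 0.3]  # the cache was still written
```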