sglang / Commits / 9ff9fa7f

Commit 9ff9fa7f (unverified), authored Oct 29, 2025 by Trevor Morris, committed via GitHub on Oct 29, 2025.

Fuse wk and weight_proj in Indexer for DeepSeekV3.2-FP4 (#12094)
Parent: 7ed8ba05

Showing 2 changed files with 110 additions and 22 deletions:

- python/sglang/srt/layers/attention/nsa/nsa_indexer.py (+45, -22)
- python/sglang/srt/models/deepseek_v2.py (+65, -0)
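The idea behind the change: the Indexer's two input projections, wk (hidden_size to head_dim) and weights_proj (hidden_size to n_heads), read the same activation, so when both are unquantized they can be computed by a single GEMM over a concatenated weight matrix and split afterwards, saving a separate kernel launch and a second read of the input. A minimal standalone sketch of that equivalence (plain nn.Linear with illustrative sizes, not the sglang ReplicatedLinear API):

```python
import torch
import torch.nn as nn

hidden_size, head_dim, n_heads = 16, 8, 4
x = torch.randn(2, hidden_size)

wk = nn.Linear(hidden_size, head_dim, bias=False)
weights_proj = nn.Linear(hidden_size, n_heads, bias=False)

fused = nn.Linear(hidden_size, head_dim + n_heads, bias=False)
with torch.no_grad():
    # nn.Linear weights are (out_features, in_features), so concatenating
    # along dim 0 stacks the two layers' output features.
    fused.weight.copy_(torch.cat([wk.weight, weights_proj.weight], dim=0))

key, weights = fused(x).split([head_dim, n_heads], dim=-1)
assert torch.allclose(key, wk(x), atol=1e-6)
assert torch.allclose(weights, weights_proj(x), atol=1e-6)
```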
python/sglang/srt/layers/attention/nsa/nsa_indexer.py (view file @ 9ff9fa7f)
```diff
@@ -119,6 +119,7 @@ class Indexer(CustomOp):
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
+        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
```
```diff
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = align(self.sm_count // 2, 8)
```
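In the context line above, `align(self.sm_count // 2, 8)` rounds half the device's SM count up to a multiple of 8. A sketch of what such a helper computes (my own round-up definition for illustration, not sglang's actual implementation):

```python
def align(x: int, multiple: int) -> int:
    # Hypothetical helper mirroring the call above: round x up to a multiple.
    return (x + multiple - 1) // multiple * multiple

# e.g. on a 132-SM GPU: half the device is 66 SMs, aligned up to 72.
assert align(132 // 2, 8) == 72
```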
```diff
@@ -140,21 +142,29 @@ class Indexer(CustomOp):
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
-        self.wk = ReplicatedLinear(
-            self.hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("wk", prefix),
-        )
-        # NOTE: weight_proj is not quantized
-        self.weights_proj = ReplicatedLinear(
-            self.hidden_size,
-            self.n_heads,
-            bias=False,
-            prefix=add_prefix("weights_proj", prefix),
-        )
+        if self.fuse_wk_and_weights_proj:
+            self.fused_wk_and_weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim + self.n_heads,
+                bias=False,
+                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+            )
+        else:
+            self.wk = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("wk", prefix),
+            )
+            # NOTE: weight_proj is not quantized
+            self.weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.n_heads,
+                bias=False,
+                prefix=add_prefix("weights_proj", prefix),
+            )
         self.k_norm = V32LayerNorm(self.head_dim)
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
```
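Note what the fused branch drops: it builds the ReplicatedLinear without quant_config, since weights_proj was never quantized (see the NOTE kept in the else branch), so the fusion is only sound when wk is also left in BF16, which is exactly the FP4-checkpoint case gated in deepseek_v2.py below. ReplicatedLinear also returns an (output, output_bias) tuple, which is why the forward-path hunks further down take `[0]` before splitting. A runnable shape sketch with illustrative sizes (not read from any real config):

```python
import torch

# One GEMM yields both the index key and the per-head gate logits,
# recovered by a split over the last dimension.
hidden_size, head_dim, n_heads = 256, 128, 64
w_fused = torch.randn(head_dim + n_heads, hidden_size)  # (out_features, in_features)
x = torch.randn(3, hidden_size)
key, weights = (x @ w_fused.T).split([head_dim, n_heads], dim=-1)
assert key.shape == (3, head_dim) and weights.shape == (3, n_heads)
```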
```diff
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
         self.softmax_scale = self.head_dim**-0.5
 
     @torch.compile(dynamic=True)
-    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
-        weights, _ = self.weights_proj(x)
+    def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
```
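This signature change hoists the weights_proj matmul out of the torch.compile'd helper: callers now pass the raw gate logits in, so the fused path can feed the split-off weights tensor directly. The remaining arithmetic, restated standalone with made-up shapes (the real tensor layouts in sglang may differ):

```python
import torch

n_heads, head_dim = 64, 128
softmax_scale = head_dim ** -0.5
weights = torch.randn(3, n_heads)      # per-token, per-head gate logits
q_scale = torch.randn(3, n_heads, 1)   # assumed broadcast-compatible query scale
gate = weights * n_heads ** -0.5       # normalize across heads
gate = gate.unsqueeze(-1) * q_scale * softmax_scale
assert gate.shape == (3, n_heads, 1)
```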
```diff
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
         positions: torch.Tensor,
         enable_dual_stream: bool,
     ):
-
+        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
```
```diff
@@ -199,7 +208,12 @@ class Indexer(CustomOp):
             )
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
-                key, _ = self.wk(x)
+                if self.fuse_wk_and_weights_proj:
+                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                        [self.head_dim, self.n_heads], dim=-1
+                    )
+                else:
+                    key, _ = self.wk(x)
                 key = self.k_norm(key)
                 k_rope, _ = torch.split(
```
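Orthogonal to the fusion but central to this hunk: the key-side work runs on self.alt_stream so it overlaps with query-side work on the default stream. A standalone sketch of that pattern (simplified; sglang's actual synchronization points differ):

```python
import torch

if torch.cuda.is_available():
    alt_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)        # alt stream sees x as ready
    with torch.cuda.stream(alt_stream):
        key_like = x @ x                   # runs on the alternate stream
    current.wait_stream(alt_stream)        # rejoin before consuming the result
    out = key_like.sum()
```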
```diff
@@ -217,7 +231,12 @@ class Indexer(CustomOp):
                 query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
             )
-            key, _ = self.wk(x)
+            if self.fuse_wk_and_weights_proj:
+                key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                    [self.head_dim, self.n_heads], dim=-1
+                )
+            else:
+                key, _ = self.wk(x)
             key = self.k_norm(key)
             k_rope, _ = torch.split(
                 key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
```
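Both branches end in the same list-form torch.split of the key into its RoPE and non-RoPE parts. For readers unfamiliar with that form, a quick standalone illustration (sizes are arbitrary):

```python
import torch

key = torch.randn(4, 128)   # (tokens, head_dim)
rope_head_dim = 64
# A list of sizes gives uneven, ordered chunks along the chosen dim.
k_rope, k_pass = torch.split(key, [rope_head_dim, 128 - rope_head_dim], dim=-1)
assert k_rope.shape == (4, 64) and k_pass.shape == (4, 64)
```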
```diff
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
         query = rotate_activation(query)
         key = rotate_activation(key)
-        return query, key
+        return query, key, weights
 
     def _get_topk_paged(
         self,
```
```diff
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
-        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+        query, key, weights = self._get_q_k_bf16(
+            q_lora, x, positions, enable_dual_stream
+        )
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
```
```diff
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
             index_k_scale=k_scale,
         )
-        weights = self._get_logits_head_gate(x, q_scale)
+        if not self.fuse_wk_and_weights_proj:
+            weights, _ = self.weights_proj(x)
+        weights = self._get_logits_head_gate(weights, q_scale)
         if is_cuda():
             assert forward_batch.seq_lens_cpu is not None
```
python/sglang/srt/models/deepseek_v2.py (view file @ 9ff9fa7f)
```diff
@@ -224,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    The NSA Indexer's wk and weights_proj can be fused in the FP4 model because they are both in BF16
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
```
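The predicate restated standalone, with stub types so it can run outside sglang (is_deepseek_nsa and QuantizationConfig are replaced by stand-ins here):

```python
from typing import Optional

class StubQuant:
    """Stand-in for sglang's QuantizationConfig, exposing only get_name()."""
    def __init__(self, name: str):
        self._name = name
    def get_name(self) -> str:
        return self._name

def fused(is_nsa: bool, quant: Optional[StubQuant]) -> bool:
    # Fusion turns on only for NSA (DeepSeek V3.2-style) configs quantized
    # with ModelOpt FP4, where wk and weights_proj are both kept in BF16.
    return is_nsa and quant is not None and quant.get_name() == "modelopt_fp4"

assert fused(True, StubQuant("modelopt_fp4"))
assert not fused(True, StubQuant("fp8"))            # other quant schemes: no fusion
assert not fused(False, StubQuant("modelopt_fp4"))  # non-NSA models: no fusion
assert not fused(True, None)                        # unquantized: no fusion
```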
```diff
@@ -1143,6 +1154,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )
         self.kv_b_proj = ColumnParallelLinear(
```
```diff
@@ -3413,6 +3427,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
```
```diff
@@ -3584,6 +3602,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                             )
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
+                    elif fuse_wk_and_weights_proj and (
+                        "wk" in name or "weights_proj" in name
+                    ):
+                        cached_wk_and_weights_proj[name] = loaded_weight
+                        wk_name = (
+                            name
+                            if "wk" in name
+                            else name.replace("weights_proj", "wk")
+                        )
+                        weights_proj_name = (
+                            name
+                            if "weights_proj" in name
+                            else name.replace("wk", "weights_proj")
+                        )
+                        # When both wk and weights_proj have been cached, load the fused weight into the parameter
+                        if (
+                            wk_name in cached_wk_and_weights_proj
+                            and weights_proj_name in cached_wk_and_weights_proj
+                        ):
+                            wk_weight = cached_wk_and_weights_proj[wk_name]
+                            weights_proj_weight = cached_wk_and_weights_proj[
+                                weights_proj_name
+                            ]
+                            # TODO: dequantize wk for fp8
+                            assert wk_weight.dtype == weights_proj_weight.dtype
+                            fused_weight = torch.cat(
+                                [wk_weight, weights_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("wk", "fused_wk_and_weights_proj")
+                                if "wk" in name
+                                else name.replace(
+                                    "weights_proj",
+                                    "fused_wk_and_weights_proj",
+                                )
+                            )
+                            param = params_dict[param_name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, fused_weight)
+                            )
+                            cached_wk_and_weights_proj.pop(wk_name)
+                            cached_wk_and_weights_proj.pop(weights_proj_name)
                     else:
                         if (
                             "k_scale" in name or "v_scale" in name
```
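The loading logic has to handle checkpoint weights for wk and weights_proj arriving in either order, so each is cached until its partner appears, then the pair is concatenated along dim 0 (the out_features axis, matching the constructor's head_dim + n_heads) and loaded under the fused parameter name. A standalone sketch of that deferred-fusion pattern (names and sizes are illustrative, not sglang's real parameter layout):

```python
import torch

cached = {}
params = {"layer0.fused_wk_and_weights_proj": torch.empty(192, 256)}

def load(name: str, weight: torch.Tensor):
    cached[name] = weight
    # Derive both partner names from whichever one just arrived.
    wk_name = name if "wk" in name else name.replace("weights_proj", "wk")
    wp_name = name if "weights_proj" in name else name.replace("wk", "weights_proj")
    if wk_name in cached and wp_name in cached:
        # Pair complete: stack out_features and load the fused parameter.
        fused = torch.cat([cached.pop(wk_name), cached.pop(wp_name)], dim=0)
        params[wk_name.replace("wk", "fused_wk_and_weights_proj")].copy_(fused)

load("layer0.wk", torch.randn(128, 256))           # cached; partner not seen yet
load("layer0.weights_proj", torch.randn(64, 256))  # pair complete -> fused load
assert params["layer0.fused_wk_and_weights_proj"].shape == (192, 256)
```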