Unverified commit 9ff9fa7f, authored Oct 29, 2025 by Trevor Morris, committed via GitHub on Oct 29, 2025.
Fuse wk and weight_proj in Indexer for DeepSeekV3.2-FP4 (#12094)
Parent: 7ed8ba05

Showing 2 changed files with 110 additions and 22 deletions (+110 -22):

    python/sglang/srt/layers/attention/nsa/nsa_indexer.py  (+45 -22)
    python/sglang/srt/models/deepseek_v2.py  (+65 -0)
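For orientation before the diff: the Indexer currently runs two separate projections over the same hidden states, wk (hidden_size -> head_dim) and weights_proj (hidden_size -> n_heads), and this commit optionally replaces them with a single GEMM whose output is split afterwards. The sketch below illustrates that equivalence with plain torch.nn.Linear stand-ins and made-up sizes; it is not sglang's ReplicatedLinear code, just the underlying idea.

import torch
import torch.nn as nn

# Illustrative sizes only, not the real DeepSeekV3.2 dimensions.
hidden_size, head_dim, n_heads = 256, 128, 64

wk = nn.Linear(hidden_size, head_dim, bias=False, dtype=torch.bfloat16)
weights_proj = nn.Linear(hidden_size, n_heads, bias=False, dtype=torch.bfloat16)

# One fused layer whose weight rows are [wk ; weights_proj], built the same way the
# loader below does it, with torch.cat(..., dim=0).
fused = nn.Linear(hidden_size, head_dim + n_heads, bias=False, dtype=torch.bfloat16)
with torch.no_grad():
    fused.weight.copy_(torch.cat([wk.weight, weights_proj.weight], dim=0))

x = torch.randn(4, hidden_size, dtype=torch.bfloat16)
key, weights = fused(x).split([head_dim, n_heads], dim=-1)

# The split halves match the two separate projections (up to BF16 rounding),
# so one kernel launch replaces two.
assert torch.allclose(key.float(), wk(x).float(), atol=1e-2, rtol=1e-2)
assert torch.allclose(weights.float(), weights_proj(x).float(), atol=1e-2, rtol=1e-2)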
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -119,6 +119,7 @@ class Indexer(CustomOp):
         prefix: str = "",
         quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
+        fuse_wk_and_weights_proj: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +130,7 @@ class Indexer(CustomOp):
         self.q_lora_rank = q_lora_rank
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = align(self.sm_count // 2, 8)
@@ -140,21 +142,29 @@ class Indexer(CustomOp):
             quant_config=quant_config,
             prefix=add_prefix("wq_b", prefix),
         )
-        self.wk = ReplicatedLinear(
-            self.hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("wk", prefix),
-        )
+        if self.fuse_wk_and_weights_proj:
+            self.fused_wk_and_weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim + self.n_heads,
+                bias=False,
+                prefix=add_prefix("fused_wk_and_weights_proj", prefix),
+            )
+        else:
+            self.wk = ReplicatedLinear(
+                self.hidden_size,
+                self.head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("wk", prefix),
+            )
+            # NOTE: weight_proj is not quantized
+            self.weights_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.n_heads,
+                bias=False,
+                prefix=add_prefix("weights_proj", prefix),
+            )
         self.k_norm = V32LayerNorm(self.head_dim)
-        # NOTE: weight_proj is not quantized
-        self.weights_proj = ReplicatedLinear(
-            self.hidden_size,
-            self.n_heads,
-            bias=False,
-            prefix=add_prefix("weights_proj", prefix),
-        )
         self.rotary_emb = get_rope_wrapper(
             rope_head_dim,
             rotary_dim=rope_head_dim,
@@ -169,8 +179,7 @@ class Indexer(CustomOp):
         self.softmax_scale = self.head_dim**-0.5
 
     @torch.compile(dynamic=True)
-    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
-        weights, _ = self.weights_proj(x)
+    def _get_logits_head_gate(self, weights: torch.Tensor, q_scale: torch.Tensor):
         weights = weights * self.n_heads**-0.5
         weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
         return weights
@@ -182,7 +191,7 @@ class Indexer(CustomOp):
         positions: torch.Tensor,
         enable_dual_stream: bool,
     ):
+        weights = None
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -199,7 +208,12 @@ class Indexer(CustomOp):
             )
 
             with torch.cuda.stream(self.alt_stream):
                 # TODO we should also put DeepGEMM half SM here?
-                key, _ = self.wk(x)
+                if self.fuse_wk_and_weights_proj:
+                    key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                        [self.head_dim, self.n_heads], dim=-1
+                    )
+                else:
+                    key, _ = self.wk(x)
                 key = self.k_norm(key)
                 k_rope, _ = torch.split(
@@ -217,7 +231,12 @@ class Indexer(CustomOp):
                 query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
             )
-            key, _ = self.wk(x)
+            if self.fuse_wk_and_weights_proj:
+                key, weights = self.fused_wk_and_weights_proj(x)[0].split(
+                    [self.head_dim, self.n_heads], dim=-1
+                )
+            else:
+                key, _ = self.wk(x)
             key = self.k_norm(key)
             k_rope, _ = torch.split(
                 key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
@@ -240,7 +259,7 @@ class Indexer(CustomOp):
         query = rotate_activation(query)
         key = rotate_activation(key)
-        return query, key
+        return query, key, weights
 
     def _get_topk_paged(
         self,
@@ -490,7 +509,9 @@ class Indexer(CustomOp):
         if metadata is None:
             return None
 
-        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+        query, key, weights = self._get_q_k_bf16(
+            q_lora, x, positions, enable_dual_stream
+        )
 
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
@@ -517,7 +538,9 @@ class Indexer(CustomOp):
             index_k_scale=k_scale,
         )
 
-        weights = self._get_logits_head_gate(x, q_scale)
+        if not self.fuse_wk_and_weights_proj:
+            weights, _ = self.weights_proj(x)
+        weights = self._get_logits_head_gate(weights, q_scale)
 
         if is_cuda():
             assert forward_batch.seq_lens_cpu is not None

python/sglang/srt/models/deepseek_v2.py
@@ -224,6 +224,17 @@ def add_forward_absorb_core_attention_backend(backend_name):
     logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
 
 
+def is_nsa_indexer_wk_and_weights_proj_fused(config, quant_config):
+    """
+    NSA Indexer wk and weights_proj can be fused in FP4 model because they are both in BF16
+    """
+    return (
+        is_deepseek_nsa(config)
+        and quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+    )
+
+
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
     MHA = auto()
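The docstring above states the condition that makes the fusion safe: in the modelopt_fp4 checkpoints the Indexer's wk and weights_proj both remain BF16 (weights_proj is never quantized), so their weight rows can be concatenated directly. A small illustration of that dtype requirement, with made-up shapes rather than the real model sizes:

import torch

wk_weight = torch.randn(128, 256, dtype=torch.bfloat16)            # [head_dim, hidden_size]
weights_proj_weight = torch.randn(64, 256, dtype=torch.bfloat16)   # [n_heads, hidden_size]

# Mirrors the assert in the weight loader further down: fusion only makes sense when
# both halves share a dtype.
assert wk_weight.dtype == weights_proj_weight.dtype
fused_weight = torch.cat([wk_weight, weights_proj_weight], dim=0)  # [head_dim + n_heads, hidden_size]

# If wk were stored as packed FP4 blocks instead, its raw bytes could not be concatenated
# meaningfully with a BF16 matrix without dequantizing first (compare the
# "todo dequantize wk for fp8" note in the loader).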
@@ -1143,6 +1154,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 alt_stream=alt_stream,
+                fuse_wk_and_weights_proj=is_nsa_indexer_wk_and_weights_proj_fused(
+                    config, quant_config
+                ),
             )
 
         self.kv_b_proj = ColumnParallelLinear(
@@ -3413,6 +3427,10 @@ class DeepseekV2ForCausalLM(nn.Module):
             self.config.q_lora_rank is not None
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
+        fuse_wk_and_weights_proj = is_nsa_indexer_wk_and_weights_proj_fused(
+            self.config, self.quant_config
+        )
+        cached_wk_and_weights_proj = {} if fuse_wk_and_weights_proj else None
 
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
@@ -3584,6 +3602,53 @@ class DeepseekV2ForCausalLM(nn.Module):
                             )
                             cached_a_proj.pop(q_a_proj_name)
                             cached_a_proj.pop(kv_a_proj_name)
+                    elif fuse_wk_and_weights_proj and (
+                        "wk" in name or "weights_proj" in name
+                    ):
+                        cached_wk_and_weights_proj[name] = loaded_weight
+                        wk_name = (
+                            name if "wk" in name else name.replace("weights_proj", "wk")
+                        )
+                        weights_proj_name = (
+                            name
+                            if "weights_proj" in name
+                            else name.replace("wk", "weights_proj")
+                        )
+                        # When both wk and weights_proj has been cached, load the fused weight to parameter
+                        if (
+                            wk_name in cached_wk_and_weights_proj
+                            and weights_proj_name in cached_wk_and_weights_proj
+                        ):
+                            wk_weight = cached_wk_and_weights_proj[wk_name]
+                            weights_proj_weight = cached_wk_and_weights_proj[
+                                weights_proj_name
+                            ]
+                            # todo dequantize wk for fp8
+                            assert wk_weight.dtype == weights_proj_weight.dtype
+                            fused_weight = torch.cat(
+                                [wk_weight, weights_proj_weight], dim=0
+                            )
+                            param_name = (
+                                name.replace("wk", "fused_wk_and_weights_proj")
+                                if "wk" in name
+                                else name.replace(
+                                    "weights_proj",
+                                    "fused_wk_and_weights_proj",
+                                )
+                            )
+                            param = params_dict[param_name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, fused_weight)
+                            )
+                            cached_wk_and_weights_proj.pop(wk_name)
+                            cached_wk_and_weights_proj.pop(weights_proj_name)
                     else:
                         if (
                             "k_scale" in name or "v_scale" in name
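A note on the loading logic above: checkpoint shards deliver tensors one at a time and in arbitrary order, so the fused parameter can only be written once both halves have arrived; until then each half is parked in cached_wk_and_weights_proj. A stripped-down sketch of that cache-then-fuse pattern follows; the helper, parameter names, and shapes are illustrative, not sglang's actual loader.

import torch

_cache = {}

def maybe_fuse(name, tensor):
    # Cache one half; once its partner is present, return (fused_param_name, fused_weight).
    _cache[name] = tensor
    wk_name = name if "wk" in name else name.replace("weights_proj", "wk")
    weights_proj_name = name if "weights_proj" in name else name.replace("wk", "weights_proj")
    if wk_name in _cache and weights_proj_name in _cache:
        fused = torch.cat([_cache.pop(wk_name), _cache.pop(weights_proj_name)], dim=0)
        return wk_name.replace("wk", "fused_wk_and_weights_proj"), fused
    return None

# The two halves can arrive in either order:
print(maybe_fuse("layers.0.indexer.wk.weight", torch.zeros(128, 256)))   # None: still waiting
name, fused = maybe_fuse("layers.0.indexer.weights_proj.weight", torch.zeros(64, 256))
print(name, tuple(fused.shape))
# layers.0.indexer.fused_wk_and_weights_proj.weight (192, 256)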