sglang / Commits / 3f87f831

Fuse q_a_proj and kv_a_proj (#5619)

Commit 3f87f831 (unverified), authored Apr 22, 2025 by Baizhou Zhang, committed via GitHub on Apr 22, 2025.
Parent: ce5412b6
Showing 1 changed file with 78 additions and 25 deletions:

  python/sglang/srt/models/deepseek_v2.py  (+78, -25)
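The change itself: when q_lora_rank is not None, the MLA attention module used to run two separate projections over the same hidden_states (q_a_proj and kv_a_proj_with_mqa). This commit replaces them with a single fused_qkv_a_proj_with_mqa ReplicatedLinear and splits its output, so the down-projection costs one GEMM instead of two. A minimal standalone sketch (not sglang code; toy sizes chosen for illustration) of why the fusion is exact:

import torch

# Toy sizes, purely illustrative.
hidden_size, q_lora_rank, kv_dim = 32, 8, 6  # kv_dim stands in for kv_lora_rank + qk_rope_head_dim
x = torch.randn(4, hidden_size)

q_a = torch.nn.Linear(hidden_size, q_lora_rank, bias=False)
kv_a = torch.nn.Linear(hidden_size, kv_dim, bias=False)

# Fused layer: concatenate the two weights along the output dimension
# (dim 0 of a [out_features, in_features] weight).
fused = torch.nn.Linear(hidden_size, q_lora_rank + kv_dim, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q_a.weight, kv_a.weight], dim=0))

# One GEMM followed by a split matches the two separate GEMMs.
q_fused, kv_fused = fused(x).split([q_lora_rank, kv_dim], dim=-1)
assert torch.allclose(q_fused, q_a(x), atol=1e-6)
assert torch.allclose(kv_fused, kv_a(x), atol=1e-6)

Both of the replaced layers are ReplicatedLinear (unsharded), which is what makes a plain concatenation along the output dimension safe under tensor parallelism.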
@@ -443,12 +443,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         # For tensor parallel attention
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
-                self.q_lora_rank,
+                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
+                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
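The only width that changes in the constructor is the fused layer's output dimension, which is simply the sum of the two output widths it replaces; q_a_layernorm, q_b_proj, and kv_b_proj keep their shapes. A hedged arithmetic illustration, using DeepSeek-V3-style sizes assumed here for concreteness rather than read from any config:

# Illustrative sizes only (assumed, not taken from a real config.json).
q_lora_rank = 1536        # low-rank query compression
kv_lora_rank = 512        # compressed kv latent
qk_rope_head_dim = 64     # rotary part of the shared key

# Output width of fused_qkv_a_proj_with_mqa = width of q_a_proj + width of kv_a_proj_with_mqa.
fused_width = q_lora_rank + (kv_lora_rank + qk_rope_head_dim)
assert fused_width == 2112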
@@ -470,6 +470,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
             )
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
+            )
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -490,14 +498,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
         )
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         if rope_scaling:
@@ -656,15 +656,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
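In the forward path above, the q_lora_rank branch now reads hidden_states once: a single projection, a split into the query latent and latent_cache, and a further split of latent_cache into the compressed kv and the rotary key part; the else branch (no query LoRA) keeps the separate kv_a_proj_with_mqa call. A self-contained sketch of that dataflow with made-up sizes (illustrative names, not the sglang implementation):

import torch

# Illustrative sizes only.
tokens, hidden = 4, 32
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 8, 6, 2

hidden_states = torch.randn(tokens, hidden)
fused_w = torch.randn(q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden)

# One GEMM over hidden_states ...
fused_out = hidden_states @ fused_w.T
# ... split off the query latent from the joint kv latent + rotary key ...
q, latent_cache = fused_out.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1
)
# ... then split the latent cache into the compressed kv and the rope'd key part.
kv_a, k_pe = latent_cache.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

print(q.shape, kv_a.shape, k_pe.shape)  # (4, 8) (4, 6) (4, 2)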
@@ -699,13 +702,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         if self.use_deep_gemm_bmm:
@@ -744,7 +750,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out.transpose(0, 1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         k_nope = latent_cache[..., : self.kv_lora_rank]
         k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
@@ -819,13 +824,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
         )
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         if self.w_kc.dtype == torch.float8_e4m3fnuz:
@@ -846,8 +854,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
         v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
         k_input = latent_cache.unsqueeze(1)
@@ -1018,15 +1024,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         # First do normal mha forward to get output for extended part
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -1668,6 +1676,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                 num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
             )
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             # TODO(HandH1998): Modify it when nextn is supported.
@@ -1723,6 +1737,45 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if fuse_qkv_a_proj and (
+                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                ):
+                    cached_a_proj[name] = loaded_weight
+                    q_a_proj_name = (
+                        name
+                        if "q_a_proj" in name
+                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                    )
+                    kv_a_proj_name = (
+                        name
+                        if "kv_a_proj_with_mqa" in name
+                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                    )
+                    # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                    if (
+                        q_a_proj_name in cached_a_proj
+                        and kv_a_proj_name in cached_a_proj
+                    ):
+                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        fused_weight = torch.cat(
+                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                        )
+                        param_name = name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
+                        param = params_dict[param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, fused_weight)
+                        cached_a_proj.pop(q_a_proj_name)
+                        cached_a_proj.pop(kv_a_proj_name)
+                else:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
...
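Because checkpoints still store q_a_proj and kv_a_proj_with_mqa separately, the loader above caches whichever tensor of the pair arrives first and, once both are present, concatenates them along dim 0 (the out_features dimension of a [out_features, in_features] weight) before handing the result to the fused parameter's weight_loader. A simplified, hedged sketch of that pairing pattern (it ignores quantized weights and sglang's weight_loader hooks; names are illustrative):

import torch

def fuse_streamed_weights(weights):
    """weights: iterable of (name, tensor) pairs arriving in arbitrary order."""
    cached = {}
    fused = {}
    for name, w in weights:
        if "q_a_proj" in name or "kv_a_proj_with_mqa" in name:
            cached[name] = w
            q_name = name if "q_a_proj" in name else name.replace("kv_a_proj_with_mqa", "q_a_proj")
            kv_name = name if "kv_a_proj_with_mqa" in name else name.replace("q_a_proj", "kv_a_proj_with_mqa")
            # Only fuse once both halves of the pair have arrived.
            if q_name in cached and kv_name in cached:
                fused_name = q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
                # dim 0 is out_features for [out_features, in_features] linear weights.
                fused[fused_name] = torch.cat([cached.pop(q_name), cached.pop(kv_name)], dim=0)
        else:
            fused[name] = w
    return fused

# Example: the kv projection may arrive before the q projection.
stream = [
    ("layers.0.self_attn.kv_a_proj_with_mqa.weight", torch.randn(6, 32)),
    ("layers.0.self_attn.q_a_proj.weight", torch.randn(8, 32)),
]
out = fuse_streamed_weights(stream)
assert out["layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight"].shape == (14, 32)

The concatenation order [q, kv] has to match the split order used in the forward pass: q_lora_rank first, then kv_lora_rank + qk_rope_head_dim.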