Commit d8ab6011
Overlap qk norm with two streams (#5977)

Authored May 03, 2025 by Ke Bao, committed by GitHub May 02, 2025
Parent: 6579cd7d
1 changed file with 26 additions and 6 deletions:

python/sglang/srt/models/deepseek_v2.py (+26, -6)
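The change threads a single side CUDA stream (alt_stream) from DeepseekV2Model down to each attention layer so that the q and kv layernorms, which read disjoint slices of the fused QKV-A projection output, can run concurrently while CUDA graphs are being captured. Below is a minimal, hedged sketch of that fork/join pattern outside sglang: the names q_norm, k_norm, q, and k_nope and the tensor sizes are illustrative stand-ins, and nn.LayerNorm stands in for the model's RMSNorm layers.

import torch
from torch import nn

# Illustrative modules and tensors (hypothetical sizes), not sglang objects;
# nn.LayerNorm stands in for q_a_layernorm / kv_a_layernorm.
q_norm = nn.LayerNorm(1536, device="cuda")
k_norm = nn.LayerNorm(512, device="cuda")
q = torch.randn(8, 1536, device="cuda")
k_nope = torch.randn(8, 512, device="cuda")

alt_stream = torch.cuda.Stream()
current_stream = torch.cuda.current_stream()

# Fork: the side stream first waits for the work that produced q and k_nope.
alt_stream.wait_stream(current_stream)
q = q_norm(q)                      # runs on the current (default) stream
with torch.cuda.stream(alt_stream):
    k_nope = k_norm(k_nope)        # runs concurrently on the side stream
# Join: the current stream waits before anything downstream consumes k_nope.
current_stream.wait_stream(alt_stream)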
@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+            k_nope = k_nope.unsqueeze(1)
+
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_nope_out = q_nope_out.transpose(0, 1)

-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

         if self.attention_backend == "fa3":
@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1402,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
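Note that the overlap branch is taken only when torch.cuda.is_current_stream_capturing() is true, i.e. while sglang is capturing CUDA graphs; in eager mode both layernorms stay on the current stream. The following is a hedged, standalone sketch of that gate (not sglang code; tensors and sizes are illustrative).

import torch

# torch.cuda.is_current_stream_capturing() is False in eager mode and True
# only while the current stream is being captured into a CUDA graph, so a
# fork/join like the one above is recorded once and replayed with the graph.
assert not torch.cuda.is_current_stream_capturing()

static_x = torch.randn(8, 512, device="cuda")

# Warm up on a side stream before capture (usual CUDA-graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_x * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    assert torch.cuda.is_current_stream_capturing()  # capture gate is active here
    static_y = static_x * 2  # work recorded into the graph

g.replay()  # re-runs the recorded work without relaunching Python ops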