sglang · Commit 8dc191f2 (Unverified)

Authored May 17, 2025 by fzyzcjy; committed by GitHub on May 16, 2025
Parent: 64825b83

Fix one wasted kernel in DeepSeek and minor refactor (#6316)

Showing 1 changed file with 10 additions and 23 deletions:
python/sglang/srt/models/deepseek_v2.py (+10, -23)
@@ -1336,28 +1336,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+            hidden_states = tensor_list[self.attn_tp_rank]
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
+            if not self.input_is_scattered:
+                residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]
+
+        if hidden_states.shape[0] != 0:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         if not (
             self._enable_moe_dense_fully_dp()
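The hunk above drops the rank-0 hidden_states += residual add and the duplicated layernorm branches: the reduce-scatter now always runs on the raw attention output, and when the input was not already scattered the residual is simply sliced to this rank's chunk so the add happens inside the fused post_attention_layernorm. Below is a minimal single-process sketch of the underlying equivalence (a chunk of x + r equals the chunk of x plus the chunk of r); it is not part of the commit, it elides the actual reduce-scatter, and the sizes and world size are made up for illustration.

import torch

# Hypothetical attention tensor-parallel layout for the sketch.
attn_tp_size = 4
attn_tp_rank = 1

hidden_states = torch.randn(8, 16)   # per-token hidden states
residual = torch.randn(8, 16)        # residual stream

# Old path: add the full residual first (an extra elementwise kernel),
# then take this rank's chunk.
old_chunk = (hidden_states + residual).tensor_split(attn_tp_size)[attn_tp_rank]

# New path: take this rank's chunks first and add only the local residual
# chunk (in the real code this add is fused into post_attention_layernorm).
new_chunk = (
    hidden_states.tensor_split(attn_tp_size)[attn_tp_rank]
    + residual.tensor_split(attn_tp_size)[attn_tp_rank]
)

torch.testing.assert_close(old_chunk, new_chunk)
print("chunk of (x + r) == chunk of x + chunk of r")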
@@ -1859,7 +1847,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                         q_a_proj_name in cached_a_proj
                         and kv_a_proj_name in cached_a_proj
                     ):
-
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                         fused_weight = torch.cat(
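The second hunk removes a blank line in the fused q_a_proj/kv_a_proj weight-loading path shown in the context lines, where the two cached projection weights are concatenated into one fused weight with torch.cat. Below is a small sketch of the equivalence that fusion relies on: one matmul with the concatenated weight reproduces the two separate projections. It is not part of the commit, and the dimensions and variable names are made up for illustration.

import torch

in_dim, q_lora_rank, kv_lora_rank = 32, 8, 12
x = torch.randn(4, in_dim)

# torch stores Linear weights as (out_features, in_features).
q_a_proj_weight = torch.randn(q_lora_rank, in_dim)
kv_a_proj_weight = torch.randn(kv_lora_rank, in_dim)

# Separate projections.
q_out = x @ q_a_proj_weight.t()
kv_out = x @ kv_a_proj_weight.t()

# Fused projection: concatenate along the output dimension, run one matmul,
# then split the output back into the q_a and kv_a parts.
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
fused_out = x @ fused_weight.t()
q_fused, kv_fused = fused_out.split([q_lora_rank, kv_lora_rank], dim=-1)

torch.testing.assert_close(q_out, q_fused)
torch.testing.assert_close(kv_out, kv_fused)
print("fused q_a/kv_a projection matches the separate projections")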