Commit 8beb356f (unverified), authored Apr 17, 2025 by fzyzcjy, committed via GitHub on Apr 17, 2025

Refactor DeepSeek decoder layer branches (#5205)

Parent: c776234b

Showing 1 changed file with 47 additions and 21 deletions:

python/sglang/srt/models/deepseek_v2.py (+47, -21)
@@ -18,7 +18,8 @@
 import logging
 import os
-from enum import IntEnum, auto
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
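Note on the hunk above: the dense DeepseekV2MLP.forward now accepts a forward_mode argument it does not use, so dense and MoE sublayers share one call signature and the decoder layer can invoke self.mlp(hidden_states, forward_batch.forward_mode) without branching on the sublayer type (see the last hunk below). A minimal standalone sketch of that design choice, using hypothetical stand-in classes rather than the real sglang modules:

from typing import Optional


class DenseFFN:
    """Stand-in for a dense MLP: accepts forward_mode only to match the MoE signature."""

    def forward(self, x: list, forward_mode: Optional[str] = None) -> list:
        # forward_mode is intentionally ignored on the dense path.
        return [v * 2 for v in x]


class MoEFFN:
    """Stand-in for an MoE sublayer that genuinely consumes forward_mode."""

    def forward(self, x: list, forward_mode: Optional[str] = None) -> list:
        scale = 3 if forward_mode == "decode" else 4
        return [v * scale for v in x]


def decoder_ffn(mlp, x: list, forward_mode: str) -> list:
    # One call site works for either sublayer, which is what the unified
    # signature in the hunk above enables.
    return mlp.forward(x, forward_mode)


print(decoder_ffn(DenseFFN(), [1.0, 2.0], "decode"))  # [2.0, 4.0]
print(decoder_ffn(MoEFFN(), [1.0, 2.0], "decode"))    # [3.0, 6.0]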
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output


+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

-        if is_nextn or is_sparse_layer(layer_id):
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
+
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = False

         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
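To make the classification above concrete, here is a self-contained re-statement of the same rule with toy values (the attribute names n_routed_experts, first_k_dense_replace, and moe_layer_freq come from the hunk; the numbers and the explicit enable_deepep_moe parameter are illustrative assumptions, not the real config):

from dataclasses import dataclass
from enum import Enum, auto
from types import SimpleNamespace


class FFNInputMode(Enum):
    SCATTERED = auto()  # the FFN sees 1/tp_size of the tokens
    FULL = auto()       # the FFN sees all tokens


@dataclass
class DecoderLayerInfo:
    is_sparse: bool
    ffn_input_mode: FFNInputMode


def compute_info(config, layer_id: int, is_nextn: bool, enable_deepep_moe: bool):
    # Same rule as _compute_info in the diff, with enable_deepep_moe passed
    # explicitly instead of read from global_server_args_dict.
    is_sparse = is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )
    mode = (
        FFNInputMode.SCATTERED
        if (enable_deepep_moe and is_sparse)
        else FFNInputMode.FULL
    )
    return DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=mode)


# Toy config: 3 dense layers first, then MoE every layer (illustrative values only).
cfg = SimpleNamespace(n_routed_experts=64, first_k_dense_replace=3, moe_layer_freq=1)
print(compute_info(cfg, layer_id=0, is_nextn=False, enable_deepep_moe=True))
# -> dense layer, FULL input
print(compute_info(cfg, layer_id=5, is_nextn=False, enable_deepep_moe=True))
# -> sparse layer, SCATTERED input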
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
                 positions, hidden_states, forward_batch, residual
             )
-        else:
-            return self.forward_normal(
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
                 positions, hidden_states, forward_batch, residual
             )
+        else:
+            raise NotImplementedError

-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
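The rewritten forward dispatches on ffn_input_mode with an explicit elif and a trailing raise NotImplementedError, so a future third input mode fails loudly instead of silently falling into the full-input path. A small sketch of that dispatch pattern (names and bodies are illustrative, not the sglang API):

from enum import Enum, auto


class FFNInputMode(Enum):
    SCATTERED = auto()
    FULL = auto()


def forward(mode: FFNInputMode, tokens: list) -> list:
    # Exhaustive dispatch: every known mode has a branch, anything else raises.
    if mode == FFNInputMode.SCATTERED:
        return tokens[: max(1, len(tokens) // 2)]  # placeholder for the scattered path
    elif mode == FFNInputMode.FULL:
        return tokens  # placeholder for the full-input path
    else:
        raise NotImplementedError(f"unhandled FFN input mode: {mode}")


print(forward(FFNInputMode.FULL, [1, 2, 3, 4]))       # [1, 2, 3, 4]
print(forward(FFNInputMode.SCATTERED, [1, 2, 3, 4]))  # [1, 2]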
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         return hidden_states, residual

-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual
         )
+
         hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1: