sglang · Unverified commit 8beb356f, authored Apr 17, 2025 by fzyzcjy, committed via GitHub on Apr 17, 2025

Refactor DeepSeek decoder layer branches (#5205)

Parent: c776234b
Showing 1 changed file with 47 additions and 21 deletions.

python/sglang/srt/models/deepseek_v2.py (+47, -21)
@@ -18,7 +18,8 @@
 import logging
 import os
-from enum import IntEnum, auto
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output


+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
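The two modes above describe how tokens are laid out when the FFN sublayer runs: per the comments, SCATTERED means each tensor-parallel rank only feeds roughly 1/tp_size of the tokens into its MLP, while FULL means every rank sees the whole token dimension. A rough standalone illustration of that distinction (not sglang code; the tp_size value and the plain chunking below are assumptions standing in for the real DeepEP dispatch):

import torch

tp_size = 4
hidden_states = torch.randn(8, 16)  # [num_tokens, hidden_size] for the example

# FULL: every rank passes the complete token dimension to its FFN.
full_inputs = [hidden_states for _ in range(tp_size)]

# SCATTERED: each rank keeps only ~num_tokens / tp_size rows.
scattered_inputs = list(torch.chunk(hidden_states, tp_size, dim=0))

assert full_inputs[0].shape[0] == 8
assert scattered_inputs[0].shape[0] == 8 // tp_size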
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

-        if is_nextn or is_sparse_layer(layer_id):
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
+
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = False

         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
             and global_server_args_dict["enable_deepep_moe"]
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
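A minimal usage sketch of the rule _compute_info encodes, using a stub config in place of PretrainedConfig and a plain boolean in place of global_server_args_dict["enable_deepep_moe"] (both stand-ins are assumptions for the example; the sparsity condition itself is the one in the hunk above):

from dataclasses import dataclass
from enum import Enum, auto


class _FFNInputMode(Enum):
    SCATTERED = auto()
    FULL = auto()


@dataclass
class StubConfig:
    n_routed_experts: int = 64
    first_k_dense_replace: int = 1  # layer 0 stays dense
    moe_layer_freq: int = 1         # every later layer is MoE
    num_hidden_layers: int = 4


def compute_info(config, layer_id: int, is_nextn: bool, enable_deepep_moe: bool):
    is_sparse = is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )
    ffn_input_mode = (
        _FFNInputMode.SCATTERED
        if (enable_deepep_moe and is_sparse)
        else _FFNInputMode.FULL
    )
    return is_sparse, ffn_input_mode


cfg = StubConfig()
for layer_id in range(cfg.num_hidden_layers):
    print(layer_id, compute_info(cfg, layer_id, is_nextn=False, enable_deepep_moe=True))
# layer 0 -> (False, FULL); layers 1-3 -> (True, SCATTERED)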
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
                 positions, hidden_states, forward_batch, residual
             )
-        else:
-            return self.forward_normal(
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
                 positions, hidden_states, forward_batch, residual
             )
+        else:
+            raise NotImplementedError

-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
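The old branch picked forward_deepep only when both enable_deepep_moe and self.is_sparse were set and silently fell back to forward_normal otherwise; the new branch matches on the single precomputed _FFNInputMode and raises for anything unhandled, so a future mode cannot drop into the dense path by accident. A condensed, self-contained sketch of the two selection rules (returning method names as strings purely for illustration; the real methods take positions, hidden_states, forward_batch, residual):

from enum import Enum, auto


class _FFNInputMode(Enum):
    SCATTERED = auto()
    FULL = auto()


# Old selection (removed above): two booleans checked inline in forward().
def old_branch(enable_deepep_moe: bool, is_sparse: bool) -> str:
    return "forward_deepep" if (enable_deepep_moe and is_sparse) else "forward_normal"


# New selection: one precomputed mode, with an explicit failure path.
def new_branch(mode: _FFNInputMode) -> str:
    if mode == _FFNInputMode.SCATTERED:
        return "forward_ffn_with_scattered_input"
    if mode == _FFNInputMode.FULL:
        return "forward_ffn_with_full_input"
    raise NotImplementedError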
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):

         return hidden_states, residual

-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual
         )
         hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1:
...
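One thread that ties the first and last hunks together: DeepseekV2MLP.forward now accepts an optional forward_mode, so the self.mlp(hidden_states, forward_batch.forward_mode) call above works whether self.mlp is the dense DeepseekV2MLP or DeepseekV2MoE. A minimal sketch of that shared call shape (stub classes, not the sglang ones; the uniform-call-site reading is an inference from the diff):

from typing import Any, Optional


class DenseMLPStub:
    # Mirrors the new DeepseekV2MLP.forward signature: forward_mode is accepted
    # (and may be ignored) so the caller need not know which FFN variant it holds.
    def forward(self, x, forward_mode: Optional[Any] = None):
        return x


class MoEStub:
    def forward(self, x, forward_mode: Optional[Any] = None):
        return x


for mlp in (DenseMLPStub(), MoEStub()):
    out = mlp.forward([1.0, 2.0], forward_mode=None)  # same call shape for both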