sglang · Commits

Commit febe21ce (Unverified)
Authored Apr 04, 2025 by fzyzcjy; committed by GitHub on Apr 04, 2025

Small refactor DeepEPDispatcher into subclasses (#4994)

parent a995a773
Changes: 2 changed files, with 309 additions and 191 deletions (+309 −191)

python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py   +291 −162
python/sglang/srt/models/deepseek_v2.py                    +18 −29
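The commit replaces the single mode-switching DeepEPDispatcher class with a small hierarchy: a shared _DeepEPDispatcherImplBase, one implementation class per DeepEP mode, and a thin facade that picks the right implementation for each call. The snippet below is a minimal, self-contained sketch of that shape only, not the sglang code; the class names, the string-based mode, and the placeholder bodies are all illustrative.

class _ImplBase:
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size

    def dispatch(self, tokens):
        raise NotImplementedError

    def combine(self, tokens):
        raise NotImplementedError


class _ImplNormal(_ImplBase):
    # Normal (prefill-style) path; placeholder body for the sketch.
    def dispatch(self, tokens):
        return list(tokens)

    def combine(self, tokens):
        return sum(tokens)


class _ImplLowLatency(_ImplBase):
    # Low-latency (decode-style) path; placeholder body for the sketch.
    def dispatch(self, tokens):
        return list(reversed(tokens))

    def combine(self, tokens):
        return sum(tokens)


class Dispatcher:
    """Facade: builds every enabled impl once, then delegates per call."""

    def __init__(self, hidden_size: int):
        common_kwargs = dict(hidden_size=hidden_size)
        self._normal = _ImplNormal(**common_kwargs)
        self._low_latency = _ImplLowLatency(**common_kwargs)

    def _get_impl(self, mode: str) -> _ImplBase:
        if mode == "normal":
            return self._normal
        elif mode == "low_latency":
            return self._low_latency
        raise ValueError(f"Invalid mode: {mode}")

    def dispatch(self, tokens, mode: str):
        return self._get_impl(mode).dispatch(tokens)

    def combine(self, tokens, mode: str):
        return self._get_impl(mode).combine(tokens)


if __name__ == "__main__":
    d = Dispatcher(hidden_size=16)
    print(d.dispatch([1, 2, 3], mode="normal"))       # [1, 2, 3]
    print(d.dispatch([1, 2, 3], mode="low_latency"))  # [3, 2, 1]

With this shape, mode-specific state (the normal-mode buffer versus the low-latency buffer and its recv hook) lives only in the subclass that needs it, which is the main point of the refactor.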
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (view file @ febe21ce)
@@ -23,7 +23,7 @@ _buffer_normal = None
 _buffer_low_latency = None
 
 
-def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     """
     Copy from DeepEP example usage in model inference prefilling.
     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
@@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
     return _buffer_normal
 
 
-def get_buffer_low_latency(
+def _get_buffer_low_latency(
     group: dist.ProcessGroup,
     num_max_dispatch_tokens_per_rank: int,
     hidden: int,
@@ -85,24 +85,16 @@ def get_buffer_low_latency(
     return _buffer_low_latency
 
 
-class DeepEPDispatcher:
-    """
-    Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-    """
-
+class _DeepEPDispatcherImplBase:
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
         router_topk: int,
-        permute_fusion: bool = False,
-        num_experts: int = None,
-        num_local_experts: int = None,
-        hidden_size: int = None,
-        params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
-        async_finish: bool = False,
-        return_recv_hook: bool = False,
+        permute_fusion: bool,
+        num_experts: int,
+        num_local_experts: int,
+        hidden_size: int,
+        params_dtype: torch.dtype,
     ):
         if not use_deepep:
             raise ImportError(
@@ -119,63 +111,36 @@ class DeepEPDispatcher:
         self.params_dtype = params_dtype
         self.params_bytes = 2
 
-        self.deepep_mode = deepep_mode
         self.handle = None
 
-        if self.deepep_mode.enable_normal():
-            self.buffer_normal = get_buffer_normal(
-                self.group, self.hidden_size * self.params_bytes
-            )
-            self.async_finish = async_finish
-            self.src2dst = None
-        if self.deepep_mode.enable_low_latency():
-            """
-            num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
-            https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-            """
-            # TODO(ch-wan): allow users to set this value
-            self.num_max_dispatch_tokens_per_rank = 128
-            self.buffer_low_latency = get_buffer_low_latency(
-                self.group,
-                self.num_max_dispatch_tokens_per_rank,
-                self.hidden_size,
-                self.num_experts,
-            )
-            self.return_recv_hook = return_recv_hook
-
-    def deepep_permute(
+    def dispatch(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
     ):
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
+        raise NotImplementedError
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
+    def __init__(self, async_finish: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        self.buffer_normal = _get_buffer_normal(
+            self.group, self.hidden_size * self.params_bytes
+        )
+        self.async_finish = async_finish
+        self.src2dst = None
 
     def dispatch(
         self,
@@ -183,51 +148,34 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
-        forward_mode: ForwardMode = None,
-    ):
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> Tuple:
         topk_idx = topk_idx.to(torch.int64)
-        reorder_topk_ids = torch.empty(
-            (0,), device=hidden_states.device, dtype=torch.int64
-        )
-        seg_indptr = torch.zeros(
-            (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        )
-        masked_m = torch.empty(
-            (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        )
-        expected_m = 0
-
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                event,
-            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            expected_m = (
-                hidden_states.shape[0]
-                * self.buffer_low_latency.group_size
-                * topk_idx.shape[1]
-                + num_experts
-            ) // num_experts
-            hidden_states, masked_m, event, hook = self.dispatch_low_latency(
-                hidden_states,
-                topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
-                use_fp8=True,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            event,
+        ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+        event.current_stream_wait() if self.async_finish else ()
+        if hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+
+        # TODO
+        # masked_m = torch.empty(
+        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # expected_m = 0
+        masked_m = expected_m = None
 
         return (
             hidden_states,
@@ -239,7 +187,7 @@ class DeepEPDispatcher:
             expected_m,
         )
 
-    def dispatch_normal(
+    def _dispatch_normal(
         self,
         x: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -292,7 +240,156 @@ class DeepEPDispatcher:
             event,
         )
 
-    def dispatch_low_latency(
+    def _deepep_permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
+    ):
+        """
+        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+        """
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
+        )
+        num_total_tokens = reorder_topk_ids.numel()
+        gateup_input = torch.empty(
+            (int(num_total_tokens), hidden_states.shape[1]),
+            device=hidden_states.device,
+            dtype=(
+                fp8_dtype
+                if (use_fp8_w8a8 and not use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        # PreReorder
+        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+            hidden_states,
+            gateup_input,
+            self.src2dst,
+            topk_idx,
+            None,
+            self.router_topk,
+            hidden_states.shape[1],
+            BLOCK_SIZE=512,
+        )
+        return reorder_topk_ids, seg_indptr, gateup_input
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        hidden_states, event = self._combine_normal(
+            output,
+        )
+        event.current_stream_wait() if self.async_finish else ()
+        return hidden_states
+
+    def _combine_normal(self, x: torch.Tensor):
+        previous_event = Buffer.capture() if self.async_finish else None
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            self.handle,
+            async_finish=self.async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+        return combined_x, event
+
+
+class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
+    def __init__(self, return_recv_hook: bool, **kwargs):
+        super().__init__(**kwargs)
+
+        """
+        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+        """
+        # TODO(ch-wan): allow users to set this value
+        self.num_max_dispatch_tokens_per_rank = 128
+        self.buffer_low_latency = _get_buffer_low_latency(
+            self.group,
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size,
+            self.num_experts,
+        )
+        self.return_recv_hook = return_recv_hook
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ):
+        topk_idx = topk_idx.to(torch.int64)
+        expected_m = (
+            hidden_states.shape[0]
+            * self.buffer_low_latency.group_size
+            * topk_idx.shape[1]
+            + num_experts
+        ) // num_experts
+        hidden_states, masked_m, event, hook = self._dispatch_low_latency(
+            hidden_states,
+            topk_idx,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+            use_fp8=True,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
+
+        # TODO
+        # reorder_topk_ids = torch.empty(
+        #     (0,), device=hidden_states.device, dtype=torch.int64
+        # )
+        # seg_indptr = torch.zeros(
+        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        # )
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
+    def _dispatch_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -351,62 +448,17 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
     ) -> torch.Tensor:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-            hidden_states, event = self.combine_normal(
-                output,
-            )
-            event.current_stream_wait() if self.async_finish else ()
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-            )
-            hook() if self.return_recv_hook else event.current_stream_wait()
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        hidden_states, event, hook = self._combine_low_latency(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+        )
+        hook() if self.return_recv_hook else event.current_stream_wait()
 
         return hidden_states
 
-    def combine_normal(self, x: torch.Tensor):
-        previous_event = Buffer.capture() if self.async_finish else None
-        combined_x, _, event = self.buffer_normal.combine(
-            x,
-            self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
-        )
-        return combined_x, event
-
-    def combine_low_latency(
+    def _combine_low_latency(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
@@ -423,3 +475,80 @@ class DeepEPDispatcher:
                 )
             )
         return combined_hidden_states, event, hook
+
+
+class DeepEPDispatcher:
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        async_finish: bool = False,
+        return_recv_hook: bool = False,
+    ):
+        self.deepep_mode = deepep_mode
+
+        common_kwargs = dict(
+            group=group,
+            router_topk=router_topk,
+            permute_fusion=permute_fusion,
+            num_experts=num_experts,
+            num_local_experts=num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+        )
+
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
+                return_recv_hook=return_recv_hook,
+                **common_kwargs,
+            )
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int = 128,
+        forward_mode: ForwardMode = None,
+    ) -> Tuple:
+        return self._get_dispatcher(forward_mode).dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_experts=num_experts,
+            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
+        )
+
+    def combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ) -> torch.Tensor:
+        return self._get_dispatcher(forward_mode).combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+        )
+
+    def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
+            return self._normal_dispatcher
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self._low_latency_dispatcher
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
python/sglang/srt/models/deepseek_v2.py  (view file @ febe21ce)
@@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
 
-        if not global_server_args_dict["enable_deepep_moe"]:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-            )
-        else:
-            self.experts = MoEImpl(
-                num_experts=config.n_routed_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=config.norm_topk_prob,
-                quant_config=quant_config,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                prefix=add_prefix("experts", prefix),
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            )
+        self.experts = MoEImpl(
+            num_experts=config.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
+        )
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
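The deepseek_v2.py change collapses the two duplicated MoEImpl(...) calls into a single call that splices in deepep_mode via conditional dict unpacking. Below is a minimal, self-contained sketch of that pattern only; the function and argument names are illustrative, not the sglang call.

def make_expert(name: str, **kwargs):
    # Stand-in for the MoEImpl constructor in this sketch.
    return {"name": name, **kwargs}

def build(enable_deepep: bool):
    # Shared keyword arguments are written once; the optional one is spliced in
    # only when the feature flag is set, mirroring `**({...} if cond else {})`.
    return make_expert(
        "experts",
        top_k=8,
        **(dict(deepep_mode="auto") if enable_deepep else {}),
    )

print(build(False))  # {'name': 'experts', 'top_k': 8}
print(build(True))   # {'name': 'experts', 'top_k': 8, 'deepep_mode': 'auto'}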