sglang · Commits · 9c138a04
"docs/vscode:/vscode.git/clone" did not exist on "d50ce994213a264dfb746cd5e4ebc0f148f03b17"
Unverified commit 9c138a04, authored Jul 28, 2025 by Cheng Wan, committed by GitHub on Jul 28, 2025.

[3/N] MoE Refactor: Simplify DeepEP Output (#8421)

Parent: c8f549d9
Showing 8 changed files with 319 additions and 276 deletions (+319 / -276)
python/sglang/srt/layers/moe/ep_moe/layer.py                       +150  -30
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py             +69  -118
python/sglang/srt/layers/moe/token_dispatcher/__init__.py            +0  -0
python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py    +48  -0
python/sglang/srt/layers/moe/token_dispatcher/standard.py           +19  -0
python/sglang/srt/models/deepseek_v2.py                             +13  -56
python/sglang/srt/models/qwen3_moe.py                               +12  -69
python/sglang/srt/two_batch_overlap.py                               +8  -3
python/sglang/srt/layers/moe/ep_moe/layer.py  (view file @ 9c138a04, +150 / -30)

+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
...
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
...
@@ -797,6 +806,24 @@ class DeepEPMoE(EPMoE):
                 "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
             )
 
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
+
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
...
@@ -837,37 +864,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
-        if resolved_deepep_mode == DeepEPMode.normal:
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
-    def forward_normal(
+    def combine(
         self,
         hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
...
@@ -983,10 +1101,13 @@ class DeepEPMoE(EPMoE):
     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
...
@@ -1014,11 +1135,11 @@ class DeepEPMoE(EPMoE):
     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
...
@@ -1138,10 +1259,9 @@ class DeepEPMoE(EPMoE):
     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
    ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
...
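Editorial note: the change above turns DeepEPMoE.forward into a fixed dispatch → moe_impl → combine pipeline, and moe_impl now selects a kernel path from the dispatch output's self-describing format tag instead of a resolved DeepEPMode plus loose tensors. Below is a minimal, self-contained sketch of that tag-based selection pattern; every name here (FakeFormat, FakeNormalOutput, run_contiguous, moe_impl in this sketch) is hypothetical and not an sglang API.

from enum import Enum, auto
from typing import NamedTuple


class FakeFormat(Enum):
    deepep_normal = auto()
    deepep_ll = auto()


class FakeNormalOutput(NamedTuple):
    # Stand-in for DeepEPNormalOutput: the payload carries its own format tag.
    hidden_states: list
    num_recv_tokens_per_expert: list

    @property
    def format(self) -> FakeFormat:
        return FakeFormat.deepep_normal


def run_contiguous(out: FakeNormalOutput) -> str:
    # Stand-in for forward_deepgemm_contiguous / forward_normal.
    return f"contiguous over {len(out.hidden_states)} tokens"


def moe_impl(dispatch_output) -> str:
    # Branch purely on the output's format tag, as the refactored method does.
    if dispatch_output.format is FakeFormat.deepep_normal:
        return run_contiguous(dispatch_output)
    elif dispatch_output.format is FakeFormat.deepep_ll:
        return "masked low-latency path"
    raise ValueError(f"Unsupported dispatch output format: {dispatch_output.format}")


print(moe_impl(FakeNormalOutput(hidden_states=[0, 1, 2], num_recv_tokens_per_expert=[3])))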
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (view file @ 9c138a04, +69 / -118)

+# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
+
+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    List,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    Union,
+    runtime_checkable,
+)
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
...
@@ -24,7 +44,6 @@ except ImportError:
     use_deepep = False
 
 from enum import Enum, IntEnum, auto
-from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
...
@@ -41,6 +60,37 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)
 
+
+class DeepEPNormalOutput(NamedTuple):
+    """DeepEP normal dispatch output."""
+
+    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    num_recv_tokens_per_expert: List[int]
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_normal
+
+
+class DeepEPLLOutput(NamedTuple):
+    """DeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
+
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
...
@@ -139,7 +189,7 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
-class DeepEPConfig:
+class DeepEPConfig(BaseDispatcherConfig):
     _instance = None
 
     def __init__(self):
...
@@ -255,63 +305,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                None,
-                num_recv_tokens_per_expert_list,
-                None,
-                None,
-                None,
-            )
-        else:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                num_recv_tokens_per_expert_list,
-                event,
-            ) = self._dispatch_core(
-                hidden_states, topk_idx, topk_weights, previous_event
-            )
-            event.current_stream_wait() if self.async_finish else ()
-            if hidden_states.shape[0] > 0:
-                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-                )
-            else:
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-            masked_m = expected_m = None
-            return (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                None,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            )
+        (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            num_recv_tokens_per_expert,
+            event,
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        event.current_stream_wait() if self.async_finish else ()
+        return DeepEPNormalOutput(
+            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+        )
 
     def _dispatch_core(
         self,
...
@@ -343,7 +347,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
...
@@ -362,7 +366,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             num_tokens_per_expert=num_tokens_per_expert,
...
@@ -372,58 +376,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
+            num_recv_tokens_per_expert,
             event,
         )
 
-    def _deepep_permute(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        fp8_dtype: Optional[torch.dtype] = None,
-        use_fp8_w8a8: bool = False,
-        use_block_quant: bool = False,
-    ):
-        """
-        Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
-        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
-        """
-        if _use_aiter:
-            # skip permutation here as aiter fused_moe has fused inside
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-
-        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_idx, self.num_experts
-        )
-        num_total_tokens = reorder_topk_ids.numel()
-        gateup_input = torch.empty(
-            (int(num_total_tokens), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                fp8_dtype
-                if (use_fp8_w8a8 and not use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        # PreReorder
-        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            self.src2dst,
-            topk_idx,
-            None,
-            self.router_topk,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-        )
-        return reorder_topk_ids, seg_indptr, gateup_input
-
     def combine_a(
         self,
         hidden_states: torch.Tensor,
...
@@ -544,15 +500,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        reorder_topk_ids = seg_indptr = None
-
-        return (
+        return DeepEPLLOutput(
            hidden_states,
            topk_idx,
            topk_weights,
-            reorder_topk_ids,
-            None,
-            seg_indptr,
            masked_m,
            expected_m,
        )
...
@@ -636,7 +587,7 @@ class _Stage(Enum):
     AFTER_COMBINE_A = auto()
 
 
-class DeepEPDispatcher:
+class DeepEPDispatcher(BaseDispatcher):
     def __init__(
         self,
         group: torch.distributed.ProcessGroup,
...
@@ -676,7 +627,7 @@ class DeepEPDispatcher:
         self._stage = _Stage.INITIAL
 
-    def dispatch(self, *args, **kwargs) -> Tuple:
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
         return ret
...
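Editorial note: DeepEPNormalOutput and DeepEPLLOutput replace the old 8-element return tuple that padded unused slots with None. The sketch below shows the two ways a consumer can read such a NamedTuple, by attribute or by positional unpacking (the latter is what forward_deepgemm_masked does above); DeepEPLLOutputSketch and its string "tensors" are dummy placeholders for illustration only, not the real class.

from typing import NamedTuple, Tuple


class DeepEPLLOutputSketch(NamedTuple):
    # Mirrors the field order of DeepEPLLOutput added in this commit.
    hidden_states_fp8: Tuple[object, object]
    topk_idx: object
    topk_weights: object
    masked_m: object
    expected_m: int


out = DeepEPLLOutputSketch(("fp8_data", "fp8_scale"), "idx", "weights", "mask", 128)

# Attribute access, e.g. when only one field is needed:
print(out.expected_m)

# Positional unpacking, the way forward_deepgemm_masked consumes the real output:
hidden_states_fp8, _, _, masked_m, expected_m = out
print(hidden_states_fp8, masked_m, expected_m)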
python/sglang/srt/layers/moe/token_dispatcher/__init__.py  (new file @ 9c138a04, empty)
python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py  (new file @ 9c138a04, +48)

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable

import torch


class DispatchOutputFormat(Enum):
    standard = auto()
    deepep_normal = auto()
    deepep_ll = auto()

    def is_standard(self) -> bool:
        return self == DispatchOutputFormat.standard

    def is_deepep_normal(self) -> bool:
        return self == DispatchOutputFormat.deepep_normal

    def is_deepep_ll(self) -> bool:
        return self == DispatchOutputFormat.deepep_ll


@runtime_checkable
class DispatchOutput(Protocol):
    """Protocol for dispatch outputs in different formats."""

    @property
    def format(self) -> DispatchOutputFormat: ...


class BaseDispatcherConfig(ABC):
    """Base class for dispatcher configs."""

    pass


class BaseDispatcher(ABC):
    """Base class for dispatchers."""

    @abstractmethod
    def dispatch(self, *args, **kwargs) -> DispatchOutput:
        pass

    @abstractmethod
    def combine(self, *args, **kwargs) -> torch.Tensor:
        pass
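Editorial note: DispatchOutput is a runtime_checkable Protocol, so the module-level isinstance asserts in token_dispatcher.py and standard.py are structural checks for a `format` member rather than inheritance checks. A minimal sketch of that idiom under the same assumption; Format, Output, and MyOutput are hypothetical names, not sglang classes.

from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable


class Format(Enum):
    standard = auto()


@runtime_checkable
class Output(Protocol):
    @property
    def format(self) -> Format: ...


class MyOutput(NamedTuple):
    payload: str

    @property
    def format(self) -> Format:
        return Format.standard


# Structural checks: they pass because MyOutput exposes a `format` member,
# even though it never inherits from Output.
assert isinstance(MyOutput("x"), Output)
assert isinstance(MyOutput, Output)
print("protocol checks passed")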
python/sglang/srt/layers/moe/token_dispatcher/standard.py  (new file @ 9c138a04, +19)

from __future__ import annotations

from typing import NamedTuple

from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
    DispatchOutput,
    DispatchOutputFormat,
)


class StandardDispatchOutput(NamedTuple):
    """Standard dispatch output."""

    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.standard


assert isinstance(StandardDispatchOutput, DispatchOutput)
python/sglang/srt/models/deepseek_v2.py  (view file @ 9c138a04, +13 / -56)

...
@@ -594,41 +594,13 @@ class DeepseekV2MoE(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         if shared_output is not None:
             x = shared_output
...
@@ -689,8 +661,7 @@ class DeepseekV2MoE(nn.Module):
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
...
@@ -703,46 +674,32 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
...
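Editorial note: in the two-batch-overlap path, the eight per-field state entries (topk_idx_dispatched, reorder_topk_ids, seg_indptr, and so on) collapse into a single state.dispatch_output created in op_dispatch_b, read in op_experts and op_combine_a, and popped once combine has consumed it. A tiny sketch of that handoff, using a plain dict and a hypothetical DummyOutput in place of sglang's internal operation state:

from typing import NamedTuple


class DummyOutput(NamedTuple):
    hidden_states: str
    topk_idx: str
    topk_weights: str


state = {}

# op_dispatch_b: stash the whole dispatch output under a single key.
state["dispatch_output"] = DummyOutput("h", "idx", "w")

# op_experts: the expert computation consumes the object directly.
experts_out = f"experts({state['dispatch_output'].hidden_states})"

# op_combine_a: combine reads the routing fields it needs, then the object
# is dropped in one place instead of popping eight separate keys.
topk_idx = state["dispatch_output"].topk_idx
topk_weights = state["dispatch_output"].topk_weights
state.pop("dispatch_output")

print(experts_out, topk_idx, topk_weights, state)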
python/sglang/srt/models/qwen3_moe.py  (view file @ 9c138a04, +12 / -69)

...
@@ -144,19 +144,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )
 
         self.top_k = config.num_experts_per_tok
 
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.num_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
...
@@ -207,41 +194,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             topk_weights = torch.empty(
                 (0, self.top_k), dtype=torch.float32, device=hidden_states.device
             )
-        if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_batch=forward_batch,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                hidden_states=final_hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                forward_batch=forward_batch,
-            )
 
         return final_hidden_states
 
     def op_gate(self, state):
...
@@ -278,8 +236,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            self.deepep_dispatcher.dispatch_a(
+            self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.pop("hidden_states_mlp_input"),
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
...
@@ -292,46 +249,32 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts(
-            hidden_states=state.pop("hidden_states_experts_input"),
-            topk_idx=state.topk_idx_dispatched,
-            topk_weights=state.topk_weights_dispatched,
-            reorder_topk_ids=state.pop("reorder_topk_ids"),
-            seg_indptr=state.pop("seg_indptr"),
-            masked_m=state.pop("masked_m"),
-            expected_m=state.pop("expected_m"),
-            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-            forward_batch=state.forward_batch,
+        state.hidden_states_experts_output = self.experts.moe_impl(
+            dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.deepep_dispatcher.combine_a(
+            self.experts.deepep_dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.pop("topk_idx_dispatched"),
-                topk_weights=state.pop("topk_weights_dispatched"),
+                topk_idx=state.dispatch_output.topk_idx,
+                topk_weights=state.dispatch_output.topk_weights,
                 forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
+            state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            state.hidden_states_after_combine = (
+                self.experts.deepep_dispatcher.combine_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
             )
 
     def op_output(self, state):
...
python/sglang/srt/two_batch_overlap.py  (view file @ 9c138a04, +8 / -3)

+from __future__ import annotations
+
 import dataclasses
 import logging
 from dataclasses import replace
-from typing import Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
 
 import torch
...
@@ -20,6 +22,9 @@ from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 logger = logging.getLogger(__name__)
...
@@ -802,7 +807,7 @@ class MaybeTboDeepEPDispatcher:
     def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
         return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)
 
-    def dispatch(self, **kwargs):
+    def dispatch(self, **kwargs) -> DispatchOutput:
         return self._execute("dispatch", **kwargs)
 
     def dispatch_a(self, **kwargs):
...
@@ -811,7 +816,7 @@ class MaybeTboDeepEPDispatcher:
     def dispatch_b(self, **kwargs):
         return self._execute("dispatch_b", **kwargs)
 
-    def combine(self, **kwargs):
+    def combine(self, **kwargs) -> torch.Tensor:
         return self._execute("combine", **kwargs)
 
     def combine_a(self, **kwargs):
...
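Editorial note: the DispatchOutput annotations added here are guarded by TYPE_CHECKING together with `from __future__ import annotations`, so they stay string-valued at runtime and the token dispatcher is never imported by two_batch_overlap.py outside of type checking. A generic sketch of that idiom follows; "some_module" and DispatchLikeOutput are illustrative names only.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime,
    # so no circular dependency is created. "some_module" is illustrative.
    from some_module import DispatchLikeOutput


def dispatch(**kwargs) -> DispatchLikeOutput:
    # With postponed annotations, the return annotation stays a plain string
    # at runtime, so the missing runtime import is never a problem.
    return kwargs  # placeholder result for the sketch


print(dispatch(a=1))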