Unverified commit bfc3b3f7, authored Oct 20, 2025 by Cheng Wan; committed by GitHub, Oct 20, 2025.

[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)

Parent: da5bde4d

Showing 20 changed files with 356 additions and 410 deletions.

python/sglang/srt/layers/dp_attention.py                         +17  -0
python/sglang/srt/layers/moe/ep_moe/kernels.py                    +3  -1
python/sglang/srt/layers/moe/ep_moe/layer.py                     +69  -99
python/sglang/srt/layers/moe/fused_moe_triton/layer.py           +44  -35
python/sglang/srt/layers/moe/token_dispatcher/__init__.py         +2  -0
python/sglang/srt/layers/moe/token_dispatcher/base.py             +1  -1
python/sglang/srt/layers/moe/token_dispatcher/deepep.py          +86  -91
python/sglang/srt/layers/moe/token_dispatcher/mooncake.py        +37  -39
python/sglang/srt/layers/moe/token_dispatcher/standard.py        +46  -0
python/sglang/srt/layers/moe/topk.py                              +3  -2
python/sglang/srt/layers/quantization/modelopt_quant.py           +4  -0
python/sglang/srt/layers/quantization/w4afp8.py                   +1  -1
python/sglang/srt/model_executor/cuda_graph_runner.py             +2  -0
python/sglang/srt/model_executor/forward_batch_info.py            +2  -0
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py   +4  -0
python/sglang/srt/models/bailing_moe.py                           +4  -42
python/sglang/srt/models/deepseek_v2.py                          +14  -46
python/sglang/srt/models/glm4_moe.py                              +1  -16
python/sglang/srt/models/qwen2_moe.py                             +3  -7
python/sglang/srt/models/qwen3_moe.py                            +13  -30

python/sglang/srt/layers/dp_attention.py

@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
     _global_num_tokens: Optional[List[int]]
+    _is_extend_in_batch: bool

     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):

@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
     def get_dp_device(cls) -> torch.device:
         return cls._device

+    @classmethod
+    def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
+        cls._is_extend_in_batch = is_extend_in_batch
+
+    @classmethod
+    def get_is_extend_in_batch(cls) -> bool:
+        return cls._is_extend_in_batch
+
 def set_dp_buffer_len(
     global_dp_buffer_len: int,

@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
     return _DpGatheredBufferWrapper.get_dp_device()


+def set_is_extend_in_batch(is_extend_in_batch: bool):
+    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
+
+
+def get_is_extend_in_batch() -> bool:
+    return _DpGatheredBufferWrapper.get_is_extend_in_batch()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
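
Note on the dp_attention.py change: the is_extend_in_batch flag is now cached on _DpGatheredBufferWrapper and exposed through module-level set_is_extend_in_batch / get_is_extend_in_batch helpers, mirroring the existing get_dp_device pattern, so dispatchers can read it back instead of receiving a ForwardBatch. A minimal usage sketch; the prepare-batch hook that records the flag is an assumption and is not part of this commit's diff:

    # Sketch only: illustrates the accessors added above.
    from sglang.srt.layers.dp_attention import (
        get_is_extend_in_batch,
        set_is_extend_in_batch,
    )

    def on_prepare_batch(forward_batch) -> None:
        # Record once per batch, before any MoE dispatch runs (assumed call site).
        set_is_extend_in_batch(forward_batch.is_extend_in_batch)

    def resolve_deepep_mode(deepep_mode):
        # Dispatchers can now resolve normal vs. low-latency mode
        # without threading a ForwardBatch through every call.
        return deepep_mode.resolve(get_is_extend_in_batch())
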
python/sglang/srt/layers/moe/ep_moe/kernels.py

@@ -566,7 +566,9 @@ def ep_scatter(
     scale_hidden_size = ceil_div(scale_hidden_size, 4)

     assert m_indices.shape[0] % BLOCK_E == 0
-    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert (
+        recv_x_scale.dtype == output_tensor_scale.dtype
+    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
     assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size

     _fwd_kernel_ep_scatter_1[(grid,)](
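
The kernels.py change is purely diagnostic: the dtype assertion now reports both dtypes when it fails. A standalone illustration of the same pattern; the function name here is illustrative only:

    import torch

    def check_scale_dtypes(recv_x_scale: torch.Tensor, output_tensor_scale: torch.Tensor) -> None:
        # Same style as the assert added above: embed both dtypes in the message
        # so a mismatch is diagnosable from the traceback alone.
        assert (
            recv_x_scale.dtype == output_tensor_scale.dtype
        ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
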
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -20,18 +20,14 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
-from sglang.srt.layers.quantization.modelopt_quant import (
-    CUTEDSL_MOE_NVFP4_DISPATCH,
-    ModelOptNvFp4FusedMoEMethod,
-)
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 from sglang.srt.utils.offloader import get_offloader

@@ -109,23 +105,6 @@ class DeepEPMoE(FusedMoE):
         self.deepep_mode = get_deepep_mode()

-        # TODO: move to the beginning of the file
-        from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            params_dtype=params_dtype,
-            deepep_mode=self.deepep_mode,
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
         if self.deepep_mode.enable_low_latency() and not _is_npu:
             # NPU supports low_latency deepep without deepgemm
             assert (

@@ -165,19 +144,16 @@ class DeepEPMoE(FusedMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
         forward_shared_experts=None,
         alt_stream=None,
         disable_sbo=False,
     ):
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
             # SBO args
             experts=self,
             forward_shared_experts=forward_shared_experts,

@@ -188,25 +164,14 @@ class DeepEPMoE(FusedMoE):
     def dispatch(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
-        return self.deepep_dispatcher.dispatch(
+        return self.dispatcher.dispatch(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            input_global_scale=(
-                self.w13_input_scale_quant
-                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-                and self.quant_method.enable_flashinfer_cutedsl_moe
-                and CUTEDSL_MOE_NVFP4_DISPATCH
-                else None
-            ),
+            topk_output=topk_output,
         )

-    def moe_impl(
+    def run_moe_core(
         self,
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,

@@ -240,16 +205,14 @@ class DeepEPMoE(FusedMoE):
     def combine(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional[Dict[str, Any]] = None,
     ):
-        return self.deepep_dispatcher.combine(
+        return self.dispatcher.combine(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
-            forward_batch=forward_batch,
             overlap_args=overlap_args,
         )

@@ -257,9 +220,9 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        hidden_states, topk_idx, topk_weights = (
+        hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if hidden_states.shape[0] == 0:

@@ -267,15 +230,15 @@ class DeepEPMoE(FusedMoE):
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
         # (idx == num_local_experts) meaning not used in aiter fused_moe
-        topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
+        topk_ids_copy = topk_ids.to(torch.int32)
+        topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

         return fused_moe(
             hidden_states,
             self.w13_weight,
             self.w2_weight,
             topk_weights,
-            topk_idx_copy,
+            topk_ids_copy,
             w1_scale=self.w13_weight_scale_inv,
             w2_scale=self.w2_weight_scale_inv,
             quant_type=QuantType.per_128x128,

@@ -291,18 +254,21 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
-        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
-            dispatch_output
-        )
-        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        (
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
+        ) = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
-            return hidden_states_fp8.bfloat16()
+            return hidden_states.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
         if all_tokens <= 0:
-            return hidden_states_fp8.bfloat16()
-        M, K = hidden_states_fp8.size()
+            return hidden_states.bfloat16()
+        M, K = hidden_states.size()
         N = self.w13_weight.size(1)
         scale_block_size = 128

@@ -323,35 +289,35 @@ class DeepEPMoE(FusedMoE):
             ),
         )

-        hidden_states_fp8_shape = hidden_states_fp8.shape
-        hidden_states_fp8_device = hidden_states_fp8.device
-        hidden_states_fp8_dtype = hidden_states_fp8.dtype
+        hidden_states_shape = hidden_states.shape
+        hidden_states_device = hidden_states.device
+        hidden_states_dtype = hidden_states.dtype

         input_tensor = [
             torch.empty(
                 (all_tokens, K),
-                device=hidden_states_fp8.device,
-                dtype=hidden_states_fp8.dtype,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
             ),
             (
                 # TODO check whether need `zeros`
                 torch.zeros(
                     (ceil_div(K // 128, 4), all_tokens),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.int,
                 ).transpose(0, 1)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                 else torch.empty(
                     (all_tokens, K // 128),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.float32,
                 )
             ),
         ]
         m_indices = torch.empty(
-            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+            all_tokens, device=hidden_states.device, dtype=torch.int32
         )
-        output_index = torch.empty_like(topk_idx)
+        output_index = torch.empty_like(topk_ids)
         if get_offloader().forbid_copy_engine_usage:
             num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(

@@ -367,9 +333,9 @@ class DeepEPMoE(FusedMoE):
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

         ep_scatter(
-            hidden_states_fp8,
+            hidden_states,
             hidden_states_scale,
-            topk_idx,
+            topk_ids,
             num_recv_tokens_per_expert_gpu,
             expert_start_loc,
             input_tensor[0],

@@ -378,11 +344,11 @@ class DeepEPMoE(FusedMoE):
             output_index,
             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
-        dispose_tensor(hidden_states_fp8)
+        dispose_tensor(hidden_states)

         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:

@@ -403,7 +369,7 @@ class DeepEPMoE(FusedMoE):
         del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(

@@ -425,11 +391,11 @@ class DeepEPMoE(FusedMoE):
         del down_input_fp8, down_input_scale
         gather_out = torch.empty(
-            hidden_states_fp8_shape,
-            device=hidden_states_fp8_device,
+            hidden_states_shape,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
-        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)

         return gather_out

@@ -438,13 +404,13 @@ class DeepEPMoE(FusedMoE):
         dispatch_output: DeepEPLLOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
-        hidden_states, _, _, masked_m, _ = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         output = self.quant_method.apply_without_routing_weights(
             layer=self,
-            x=hidden_states,
+            x=(hidden_states, hidden_states_scale),
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
             down_gemm_overlap_args=down_gemm_overlap_args,

@@ -466,25 +432,28 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
+        assert (
+            hidden_states_scale.dtype == torch.float32
+        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"

         # GroupGemm-0
-        num_groups, m, k = hidden_states_fp8[0].size()
+        num_groups, m, k = hidden_states.size()
         n = self.w13_weight.size(1)
         expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            hidden_states_fp8,
+            (hidden_states, hidden_states_scale),
             self.w13_weight_fp8,
             gateup_output,
             masked_m,
             expected_m,
         )
-        dispose_tensor(hidden_states_fp8[0])
+        dispose_tensor(hidden_states)

         # Act
         down_input = torch.empty(

@@ -557,11 +526,9 @@ class DeepEPMoE(FusedMoE):
     def _forward_normal(dispatch_output: DeepEPNormalOutput):
         if TYPE_CHECKING:
             assert isinstance(dispatch_output, DeepEPNormalOutput)
-        hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
-        if isinstance(hidden_states, tuple):
-            per_token_scale = hidden_states[1]
-            hidden_states = hidden_states[0]
+        hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
+            dispatch_output
+        )

         group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
             hidden_states.device

@@ -571,7 +538,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
                 weight=[self.w13_weight.permute(0, 2, 1)],
-                # per_token_scale=[per_token_scale],
+                # per_token_scale=[hidden_states_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,

@@ -591,7 +558,7 @@ class DeepEPMoE(FusedMoE):
             )[0]
         else:
             if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
-                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                     hidden_states
                 )
             # gmm1: gate_up_proj

@@ -599,7 +566,7 @@ class DeepEPMoE(FusedMoE):
                 x=[hidden_states],
                 weight=[self.w13_weight],
                 scale=[self.w13_weight_scale.to(output_dtype)],
-                per_token_scale=[per_token_scale],
+                per_token_scale=[hidden_states_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,

@@ -631,11 +598,14 @@ class DeepEPMoE(FusedMoE):
     def _forward_ll(dispatch_output: DeepEPLLOutput):
         if TYPE_CHECKING:
             assert isinstance(dispatch_output, DeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
-        if isinstance(hidden_states, tuple):
-            per_token_scale = hidden_states[1]
-            hidden_states = hidden_states[0]
+        (
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            group_list,
+            _,
+        ) = dispatch_output

         group_list = group_list.to(torch.int64)

@@ -644,7 +614,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
                 weight=[self.w13_weight.permute(0, 2, 1)],
-                # per_token_scale=[per_token_scale],
+                # per_token_scale=[hidden_states_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,

@@ -678,7 +648,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                 x=hidden_states,
                 weight_scale=self.w13_weight_scale.to(torch.float32),
-                activation_scale=per_token_scale,
+                activation_scale=hidden_states_scale,
                 bias=None,
                 quant_scale=None,
                 quant_offset=None,
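
Taken together, the ep_moe/layer.py changes mean DeepEPMoE no longer builds its own MaybeTboDeepEPDispatcher: it reuses the dispatcher created by the FusedMoE base class, takes a single TopKOutput instead of topk_idx/topk_weights/forward_batch, renames moe_impl to run_moe_core, and receives FP8 activations as separate hidden_states / hidden_states_scale fields rather than a nested tuple. A rough sketch of the resulting call shape, with illustrative variable names; the SBO and quantization details are elided, so this is not a verbatim reproduction:

    # Sketch only: approximate post-refactor call flow for a DeepEPMoE layer.
    def moe_layer_forward(experts, hidden_states, topk_output):
        # Dispatch consumes the TopKOutput directly; no ForwardBatch and no
        # input_global_scale argument here (that now travels via
        # dispatcher.set_quant_config, see token_dispatcher/deepep.py below).
        dispatch_output = experts.dispatch(hidden_states, topk_output)

        # Renamed from moe_impl: runs the expert GEMMs on the dispatched tokens.
        expert_output = experts.run_moe_core(dispatch_output)

        # Combine takes topk_ids (renamed from topk_idx) and no ForwardBatch.
        return experts.combine(
            hidden_states=expert_output,
            topk_ids=dispatch_output.topk_ids,
            topk_weights=dispatch_output.topk_weights,
        )
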
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,14 +11,19 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
+    get_deepep_mode,
+    get_moe_a2a_backend,
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
+from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardDispatcher,
     StandardDispatchOutput,

@@ -32,6 +37,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,

@@ -71,6 +77,27 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     return tile_tokens_dim


+def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
+    a2a_backend = get_moe_a2a_backend()
+    if a2a_backend.is_none():
+        return StandardDispatcher(moe_runner_config)
+    elif a2a_backend.is_deepep():
+        return MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=moe_runner_config.top_k,
+            permute_fusion=True,
+            num_experts=moe_runner_config.num_experts,
+            num_local_experts=moe_runner_config.num_local_experts,
+            hidden_size=moe_runner_config.hidden_size,
+            params_dtype=moe_runner_config.params_dtype,
+            deepep_mode=get_deepep_mode(),
+            async_finish=True,
+            return_recv_hook=True,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"

@@ -132,8 +159,6 @@ class FusedMoE(torch.nn.Module):
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.expert_map_cpu = None
-        self.expert_map_gpu = None

         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

@@ -149,19 +174,6 @@ class FusedMoE(torch.nn.Module):
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if self.moe_ep_size > 1:
-            # TODO(ch-wan): support shared experts fusion
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
-                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-            )
-            # Create a expert map for the local experts
-            self.expert_map_cpu[
-                self.moe_ep_rank
-                * self.num_local_experts : (self.moe_ep_rank + 1)
-                * self.num_local_experts
-            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results

@@ -219,7 +231,7 @@ class FusedMoE(torch.nn.Module):
         )
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
-        self.dispatcher = StandardDispatcher()
+        self.dispatcher = create_moe_dispatcher(self.moe_runner_config)

         self.should_fuse_routed_scaling_factor_in_topk = isinstance(
             self.quant_method, ModelOptNvFp4FusedMoEMethod

@@ -453,9 +465,12 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)

     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map_cpu is None:
-            return expert_id
-        return self.expert_map_cpu[expert_id].item()
+        start_idx = self.moe_ep_rank * self.num_local_experts
+        end_idx = (self.moe_ep_rank + 1) * self.num_local_experts
+        if start_idx <= expert_id < end_idx:
+            return expert_id - start_idx
+        else:
+            return -1

     def weight_loader(
         self,

@@ -804,32 +819,18 @@ class FusedMoE(torch.nn.Module):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

-        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
-            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
-                # If we are in EP mode, we need to move the expert map to GPU.
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
-
-            if self.expert_map_gpu is not None:
-                if TopKOutputChecker.format_is_standard(topk_output):
-                    topk_output = topk_output._replace(
-                        topk_ids=self.expert_map_gpu[topk_output.topk_ids]
-                    )
-                elif TopKOutputChecker.format_is_triton_kernel(topk_output):
-                    raise NotImplementedError()
-
         dispatch_output = self.dispatcher.dispatch(
             hidden_states=hidden_states, topk_output=topk_output
         )

-        # TODO: consider using symmetric memory
-        combine_input = self.quant_method.apply(
-            layer=self,
+        combine_input = self.run_moe_core(
             dispatch_output=dispatch_output,
             **kwargs,
         )

         final_hidden_states = self.dispatcher.combine(combine_input)

+        # TODO: should we add some conditions here?
         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
         ].contiguous()

@@ -839,6 +840,14 @@ class FusedMoE(torch.nn.Module):
         return final_hidden_states

+    def run_moe_core(self, dispatch_output: DispatchOutput, **kwargs) -> CombineInput:
+        # TODO: consider using symmetric memory
+        return self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+            **kwargs,
+        )
+
     @classmethod
     def make_expert_params_mapping(
         cls,
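
fused_moe_triton/layer.py now centralizes dispatcher construction in create_moe_dispatcher (StandardDispatcher when no A2A backend is configured, MaybeTboDeepEPDispatcher for DeepEP), replaces the expert_map_cpu/expert_map_gpu tensors with a purely arithmetic global-to-local expert-id mapping, and factors the quant_method.apply call into run_moe_core. The forward path reduces to dispatch, then run_moe_core, then combine; a condensed sketch of that flow, with kwargs, EP/TP details, and dtype handling omitted, so not a verbatim copy of the method:

    # Sketch only: the shape of FusedMoE.forward after this change.
    def fused_moe_forward(layer, hidden_states, topk_output, **kwargs):
        dispatch_output = layer.dispatcher.dispatch(
            hidden_states=hidden_states, topk_output=topk_output
        )
        # run_moe_core wraps quant_method.apply so subclasses such as DeepEPMoE
        # can override the expert computation without touching dispatch/combine.
        combine_input = layer.run_moe_core(dispatch_output=dispatch_output, **kwargs)
        return layer.dispatcher.combine(combine_input)
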
python/sglang/srt/layers/moe/token_dispatcher/__init__.py

@@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher.mooncake import (
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
+    StandardDispatcher,
     StandardDispatchOutput,
 )

@@ -38,6 +39,7 @@ __all__ = [
     "MooncakeCombineInput",
     "MooncakeDispatchOutput",
     "MooncakeEPDispatcher",
+    "StandardDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
python/sglang/srt/layers/moe/token_dispatcher/base.py

@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""

-    # TODO: add hidden_states to the protocol
+    hidden_states: torch.Tensor

     @property
     def format(self) -> DispatchOutputFormat: ...
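
With hidden_states promoted from a TODO comment to a declared attribute of the DispatchOutput protocol, every dispatch output type is expected to expose the dispatched tensor directly, as the DeepEP, Mooncake, and Standard outputs in this commit do. A minimal conforming type, for illustration only:

    # Sketch only: a hypothetical dispatch output satisfying the updated protocol.
    from typing import NamedTuple
    import torch

    class MyDispatchOutput(NamedTuple):
        hidden_states: torch.Tensor  # required by DispatchOutput after this change

        @property
        def format(self):
            # Real implementations return a DispatchOutputFormat member here;
            # the exact member names are not shown in this hunk.
            ...
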
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
View file @
bfc3b3f7
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from
sglang.srt.eplb.expert_distribution
import
get_global_expert_distribution_recorder
from
sglang.srt.eplb.expert_distribution
import
get_global_expert_distribution_recorder
from
sglang.srt.layers
import
deep_gemm_wrapper
from
sglang.srt.layers
import
deep_gemm_wrapper
from
sglang.srt.layers.dp_attention
import
get_is_extend_in_batch
from
sglang.srt.layers.moe.token_dispatcher.base
import
(
from
sglang.srt.layers.moe.token_dispatcher.base
import
(
BaseDispatcher
,
BaseDispatcher
,
BaseDispatcherConfig
,
BaseDispatcherConfig
,
...
@@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
...
@@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
DispatchOutput
,
DispatchOutput
,
DispatchOutputFormat
,
DispatchOutputFormat
,
)
)
from
sglang.srt.layers.moe.topk
import
TopKOutput
from
sglang.srt.layers.moe.utils
import
(
from
sglang.srt.layers.moe.utils
import
(
DeepEPMode
,
DeepEPMode
,
get_deepep_config
,
get_deepep_config
,
...
@@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto
...
@@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
_use_aiter
=
get_bool_env_var
(
"SGLANG_USE_AITER"
)
and
is_hip
()
_use_aiter
=
get_bool_env_var
(
"SGLANG_USE_AITER"
)
and
is_hip
()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -61,9 +61,9 @@ logger = logging.getLogger(__name__)
...
@@ -61,9 +61,9 @@ logger = logging.getLogger(__name__)
class
DeepEPNormalOutput
(
NamedTuple
):
class
DeepEPNormalOutput
(
NamedTuple
):
"""DeepEP normal dispatch output."""
"""DeepEP normal dispatch output."""
hidden_states
:
torch
.
Tensor
|
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
hidden_states
:
torch
.
Tensor
#
hidden_states_scale
hidden_states_scale
:
Optional
[
torch
.
Tensor
]
topk_id
x
:
torch
.
Tensor
topk_id
s
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
num_recv_tokens_per_expert
:
List
[
int
]
num_recv_tokens_per_expert
:
List
[
int
]
...
@@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
...
@@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
class
DeepEPLLOutput
(
NamedTuple
):
class
DeepEPLLOutput
(
NamedTuple
):
"""DeepEP low latency dispatch output."""
"""DeepEP low latency dispatch output."""
hidden_states_fp8
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
hidden_states
:
torch
.
Tensor
topk_idx
:
torch
.
Tensor
hidden_states_scale
:
Optional
[
torch
.
Tensor
]
topk_ids
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
topk_weights
:
torch
.
Tensor
masked_m
:
torch
.
Tensor
masked_m
:
torch
.
Tensor
expected_m
:
int
expected_m
:
int
...
@@ -314,9 +315,7 @@ class _DeepEPDispatcherImplBase:
...
@@ -314,9 +315,7 @@ class _DeepEPDispatcherImplBase:
def
dispatch_a
(
def
dispatch_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_global_scale
:
Optional
[
torch
.
Tensor
],
topk_output
:
TopKOutput
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
):
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase:
...
@@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase:
def
combine_a
(
def
combine_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_id
x
:
torch
.
Tensor
,
topk_id
s
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
):
):
...
@@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
self
.
async_finish
=
async_finish
self
.
async_finish
=
async_finish
self
.
src2dst
=
None
self
.
src2dst
=
None
self
.
quant_config
=
{}
def
dispatch_a
(
def
dispatch_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_global_scale
:
Optional
[
torch
.
Tensor
],
topk_output
:
TopKOutput
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
):
):
topk_idx
=
topk_idx
.
to
(
torch
.
int64
)
topk_weights
,
topk_ids
=
topk_output
.
topk_weights
,
topk_output
.
topk_ids
topk_ids
=
topk_ids
.
to
(
torch
.
int64
)
if
(
if
(
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
and
not
get_moe_runner_backend
().
is_cutlass
()
and
not
get_moe_runner_backend
().
is_cutlass
()
...
@@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
scale_ue8m0
=
deep_gemm_wrapper
.
DEEPGEMM_SCALE_UE8M0
,
scale_ue8m0
=
deep_gemm_wrapper
.
DEEPGEMM_SCALE_UE8M0
,
)
)
previous_event
=
Buffer
.
capture
()
if
self
.
async_finish
else
None
previous_event
=
Buffer
.
capture
()
if
self
.
async_finish
else
None
return
hidden_states
,
topk_id
x
,
topk_weights
,
previous_event
return
hidden_states
,
topk_id
s
,
topk_weights
,
previous_event
def
dispatch_b
(
self
,
hidden_states
,
topk_id
x
,
topk_weights
,
previous_event
):
def
dispatch_b
(
self
,
hidden_states
,
topk_id
s
,
topk_weights
,
previous_event
):
(
(
hidden_states
,
hidden_states
,
topk_id
x
,
topk_id
s
,
topk_weights
,
topk_weights
,
num_recv_tokens_per_expert
,
num_recv_tokens_per_expert
,
event
,
event
,
)
=
self
.
_dispatch_core
(
hidden_states
,
topk_id
x
,
topk_weights
,
previous_event
)
)
=
self
.
_dispatch_core
(
hidden_states
,
topk_id
s
,
topk_weights
,
previous_event
)
event
.
current_stream_wait
()
if
self
.
async_finish
else
()
event
.
current_stream_wait
()
if
self
.
async_finish
else
()
if
isinstance
(
hidden_states
,
tuple
):
hidden_states
,
hidden_states_scale
=
hidden_states
else
:
hidden_states_scale
=
None
return
DeepEPNormalOutput
(
return
DeepEPNormalOutput
(
hidden_states
,
topk_idx
,
topk_weights
,
num_recv_tokens_per_expert
hidden_states
,
hidden_states_scale
,
topk_ids
,
topk_weights
,
num_recv_tokens_per_expert
,
)
)
def
_dispatch_core
(
def
_dispatch_core
(
self
,
self
,
x
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
x
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
topk_id
x
:
torch
.
Tensor
,
topk_id
s
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
previous_event
,
previous_event
,
):
):
...
@@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
is_token_in_rank
,
is_token_in_rank
,
previous_event
,
previous_event
,
)
=
buffer
.
get_dispatch_layout
(
)
=
buffer
.
get_dispatch_layout
(
topk_id
x
,
topk_id
s
,
self
.
num_experts
,
self
.
num_experts
,
previous_event
=
previous_event
,
previous_event
=
previous_event
,
async_finish
=
self
.
async_finish
,
async_finish
=
self
.
async_finish
,
...
@@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
(
(
recv_x
,
recv_x
,
recv_topk_id
x
,
recv_topk_id
s
,
recv_topk_weights
,
recv_topk_weights
,
num_recv_tokens_per_expert
,
num_recv_tokens_per_expert
,
self
.
handle
,
self
.
handle
,
event
,
event
,
)
=
buffer
.
dispatch
(
)
=
buffer
.
dispatch
(
x
,
x
,
topk_idx
=
topk_id
x
,
topk_idx
=
topk_id
s
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
num_tokens_per_rank
=
num_tokens_per_rank
,
num_tokens_per_rank
=
num_tokens_per_rank
,
num_tokens_per_rdma_rank
=
num_tokens_per_rdma_rank
,
num_tokens_per_rdma_rank
=
num_tokens_per_rdma_rank
,
...
@@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
return
(
return
(
recv_x
,
recv_x
,
recv_topk_id
x
,
recv_topk_id
s
,
recv_topk_weights
,
recv_topk_weights
,
num_recv_tokens_per_expert
,
num_recv_tokens_per_expert
,
event
,
event
,
...
@@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def
combine_a
(
def
combine_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_id
x
:
torch
.
Tensor
,
topk_id
s
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
):
):
from
sglang.srt.layers.moe.ep_moe.kernels
import
(
deepep_post_reorder_triton_kernel
,
)
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
or
_use_aiter
or
_is_npu
:
if
deep_gemm_wrapper
.
ENABLE_JIT_DEEPGEMM
or
_use_aiter
or
_is_npu
:
output
=
hidden_states
output
=
hidden_states
else
:
else
:
if
hidden_states
.
shape
[
0
]
>
0
:
raise
NotImplementedError
()
# triton runner was supported but it's temporarily disabled
num_tokens
=
self
.
src2dst
.
shape
[
0
]
//
self
.
router_topk
output
=
torch
.
empty
(
(
num_tokens
,
hidden_states
.
shape
[
1
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
deepep_post_reorder_triton_kernel
[(
num_tokens
,)](
hidden_states
,
output
,
self
.
src2dst
,
topk_idx
,
topk_weights
,
self
.
router_topk
,
hidden_states
.
shape
[
1
],
BLOCK_SIZE
=
512
,
)
else
:
output
=
torch
.
zeros
(
(
0
,
hidden_states
.
shape
[
1
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
previous_event
=
Buffer
.
capture
()
if
self
.
async_finish
else
None
previous_event
=
Buffer
.
capture
()
if
self
.
async_finish
else
None
return
output
,
previous_event
return
output
,
previous_event
...
@@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
...
@@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
self
.
num_experts
,
self
.
num_experts
,
)
)
def
set_quant_config
(
self
,
quant_config
:
dict
):
self
.
quant_config
=
quant_config
class
_DeepEPDispatcherImplLowLatency
(
_DeepEPDispatcherImplBase
):
class
_DeepEPDispatcherImplLowLatency
(
_DeepEPDispatcherImplBase
):
def
__init__
(
self
,
return_recv_hook
:
bool
,
**
kwargs
):
def
__init__
(
self
,
return_recv_hook
:
bool
,
**
kwargs
):
...
@@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
"""
"""
self
.
return_recv_hook
=
return_recv_hook
self
.
return_recv_hook
=
return_recv_hook
self
.
device_module
=
torch
.
get_device_module
()
self
.
device_module
=
torch
.
get_device_module
()
self
.
quant_config
=
{}
def
dispatch_a
(
def
dispatch_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_global_scale
:
Optional
[
torch
.
Tensor
],
topk_output
:
TopKOutput
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
):
):
buffer
=
self
.
_get_buffer
()
buffer
=
self
.
_get_buffer
()
topk_idx
=
topk_idx
.
to
(
torch
.
int64
)
topk_weights
,
topk_ids
=
topk_output
.
topk_weights
,
topk_output
.
topk_ids
topk_ids
=
topk_ids
.
to
(
torch
.
int64
)
expected_m
=
(
expected_m
=
(
hidden_states
.
shape
[
0
]
*
buffer
.
group_size
*
topk_id
x
.
shape
[
1
]
hidden_states
.
shape
[
0
]
*
buffer
.
group_size
*
topk_id
s
.
shape
[
1
]
+
self
.
num_experts
+
self
.
num_experts
)
//
self
.
num_experts
)
//
self
.
num_experts
hidden_states
,
masked_m
,
event
,
hook
=
self
.
_dispatch_core
(
hidden_states
,
masked_m
,
event
,
hook
=
self
.
_dispatch_core
(
hidden_states
,
hidden_states
,
input_global_scale
,
topk_ids
,
topk_idx
,
)
)
return
(
return
(
hidden_states
,
hidden_states
,
topk_id
x
,
topk_id
s
,
topk_weights
,
topk_weights
,
masked_m
,
masked_m
,
expected_m
,
expected_m
,
...
@@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
def
dispatch_b
(
def
dispatch_b
(
self
,
self
,
hidden_states
,
hidden_states
,
topk_id
x
,
topk_id
s
,
topk_weights
,
topk_weights
,
masked_m
,
masked_m
,
expected_m
,
expected_m
,
...
@@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
masked_m
)
)
if
isinstance
(
hidden_states
,
tuple
):
hidden_states
,
hidden_states_scale
=
hidden_states
else
:
hidden_states_scale
=
None
deepep_output
=
DeepEPLLOutput
(
deepep_output
=
DeepEPLLOutput
(
hidden_states
,
hidden_states
,
topk_idx
,
hidden_states_scale
,
topk_ids
,
topk_weights
,
topk_weights
,
masked_m
,
masked_m
,
expected_m
,
expected_m
,
...
@@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
def
_dispatch_core
(
def
_dispatch_core
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_global_scale
:
Optional
[
torch
.
Tensor
],
topk_ids
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
):
):
use_nvfp4
=
use_fp8
=
False
use_nvfp4
=
use_fp8
=
False
input_global_scale
=
self
.
quant_config
.
get
(
"input_global_scale"
,
None
)
if
input_global_scale
is
not
None
:
if
input_global_scale
is
not
None
:
use_nvfp4
=
True
use_nvfp4
=
True
elif
not
get_bool_env_var
(
"SGLANG_DEEPEP_BF16_DISPATCH"
):
elif
not
get_bool_env_var
(
"SGLANG_DEEPEP_BF16_DISPATCH"
):
...
@@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
packed_recv_hidden
,
self
.
packed_recv_count
,
self
.
handle
,
event
,
hook
=
(
packed_recv_hidden
,
self
.
packed_recv_count
,
self
.
handle
,
event
,
hook
=
(
buffer
.
low_latency_dispatch
(
buffer
.
low_latency_dispatch
(
hidden_states
,
hidden_states
,
topk_id
x
,
topk_id
s
,
self
.
num_max_dispatch_tokens_per_rank
,
self
.
num_max_dispatch_tokens_per_rank
,
self
.
num_experts
,
self
.
num_experts
,
use_fp8
=
use_fp8
,
use_fp8
=
use_fp8
,
...
@@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
def
combine_a
(
def
combine_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_id
x
:
torch
.
Tensor
,
topk_id
s
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
):
):
hidden_states
,
event
,
hook
=
self
.
_combine_core
(
hidden_states
,
event
,
hook
=
self
.
_combine_core
(
hidden_states
,
hidden_states
,
topk_id
x
,
topk_id
s
,
topk_weights
,
topk_weights
,
overlap_args
=
overlap_args
,
overlap_args
=
overlap_args
,
)
)
...
@@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
def
_combine_core
(
def
_combine_core
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
topk_id
x
:
torch
.
Tensor
,
topk_id
s
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
overlap_args
:
Optional
[
"CombineOverlapArgs"
],
):
):
...
@@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
with
ctx
:
with
ctx
:
combined_hidden_states
,
event
,
hook
=
buffer
.
low_latency_combine
(
combined_hidden_states
,
event
,
hook
=
buffer
.
low_latency_combine
(
x
=
hidden_states
,
x
=
hidden_states
,
topk_idx
=
topk_id
x
,
topk_idx
=
topk_id
s
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
handle
=
self
.
handle
,
handle
=
self
.
handle
,
async_finish
=
not
self
.
return_recv_hook
,
async_finish
=
not
self
.
return_recv_hook
,
...
@@ -688,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
...
@@ -688,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
self
.
num_experts
,
self
.
num_experts
,
)
)
def
set_quant_config
(
self
,
quant_config
:
dict
):
self
.
quant_config
=
quant_config
@
dataclass
@
dataclass
class
_Stage
(
Enum
):
class
_Stage
(
Enum
):
...
@@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
...
@@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
def
dispatch_a
(
def
dispatch_a
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
input_global_scale
:
Optional
[
torch
.
Tensor
],
topk_output
:
TopKOutput
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
):
):
self
.
_update_stage
(
_Stage
.
INITIAL
,
_Stage
.
AFTER_DISPATCH_A
)
self
.
_update_stage
(
_Stage
.
INITIAL
,
_Stage
.
AFTER_DISPATCH_A
)
inner_state
=
self
.
_get_impl
(
forward_batch
).
dispatch_a
(
inner_state
=
self
.
_get_impl
().
dispatch_a
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
input_global_scale
=
input_global_scale
,
topk_output
=
topk_output
,
topk_idx
=
topk_idx
,
topk_weights
=
topk_weights
,
)
)
self
.
_dispatch_intermediate_state
=
forward_batch
,
inner_state
self
.
_dispatch_intermediate_state
=
inner_state
def
dispatch_b
(
self
):
def
dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_batch, inner_state = self._dispatch_intermediate_state
+        inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+        return self._get_impl().dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
...
@@ -773,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_batch).combine_a(
+        inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
             overlap_args=overlap_args,
         )
-        self._combine_intermediate_state = forward_batch, inner_state
+        self._combine_intermediate_state = inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_batch, inner_state = self._combine_intermediate_state
+        inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_batch).combine_b(*inner_state)
+        return self._get_impl().combine_b(*inner_state)

-    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
+    def _get_impl(self) -> _DeepEPDispatcherImplBase:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
         if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
...
@@ -807,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
     def _update_stage(self, old_stage, new_stage):
         assert self._stage == old_stage
         self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher.set_quant_config(quant_config)
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher.set_quant_config(quant_config)
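Both EP dispatchers keep the same split-phase protocol after this cleanup: dispatch_a, dispatch_b, combine_a, combine_b must run in order, and _update_stage asserts the expected stage before advancing. Below is a minimal, self-contained sketch of that stage guard; the enum and class names are simplified stand-ins for illustration, not the sglang classes themselves.

# Illustrative sketch of the stage guard used by the dispatchers: each
# a/b call asserts the previous stage and advances to the next, so
# out-of-order dispatch/combine calls fail fast. Stand-in names only.
from enum import Enum, auto

class _StageSketch(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()

class StagedDispatcherSketch:
    def __init__(self) -> None:
        self._stage = _StageSketch.INITIAL

    def _update_stage(self, old_stage, new_stage) -> None:
        assert self._stage == old_stage, f"expected {old_stage}, got {self._stage}"
        self._stage = new_stage

    def dispatch_a(self):
        self._update_stage(_StageSketch.INITIAL, _StageSketch.AFTER_DISPATCH_A)

    def dispatch_b(self):
        self._update_stage(_StageSketch.AFTER_DISPATCH_A, _StageSketch.AFTER_DISPATCH_B)

    def combine_a(self):
        self._update_stage(_StageSketch.AFTER_DISPATCH_B, _StageSketch.AFTER_COMBINE_A)

    def combine_b(self):
        self._update_stage(_StageSketch.AFTER_COMBINE_A, _StageSketch.INITIAL)

if __name__ == "__main__":
    d = StagedDispatcherSketch()
    d.dispatch_a(); d.dispatch_b(); d.combine_a(); d.combine_b()
    print("round trip ok")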
python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
View file @ bfc3b3f7
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from typing import NamedTuple, Optional, Tuple

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
...
@@ -12,6 +13,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.utils import get_int_env_var
...
@@ -27,16 +29,15 @@ from enum import Enum, auto
 import torch
 import torch.distributed as dist

-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 logger = logging.getLogger(__name__)


 class MooncakeDispatchOutput(NamedTuple):
     """Mooncake EP dispatch output."""

-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     masked_m: torch.Tensor
     expected_m: int
...
@@ -164,23 +165,23 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
+        topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
         buffer = self._get_buffer()
-        topk_idx = topk_idx.to(torch.int64)
+        topk_ids = topk_ids.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
             + self.num_experts
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             use_fp8=True,
         )
         return (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
...
@@ -191,7 +192,7 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_b(
         self,
         hidden_states,
-        topk_idx,
+        topk_ids,
         topk_weights,
         masked_m,
         expected_m,
...
@@ -206,7 +207,7 @@ class _MooncakeEPDispatcherImpl:
         return MooncakeDispatchOutput(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
...
@@ -215,14 +216,14 @@ class _MooncakeEPDispatcherImpl:
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         use_fp8: bool = False,
     ):
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.dispatch(
                 hidden_states,
-                topk_idx,
+                topk_ids,
                 self.active_ranks,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
...
@@ -237,12 +238,12 @@ class _MooncakeEPDispatcherImpl:
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
         )
         return hidden_states, event, hook
...
@@ -254,13 +255,13 @@ class _MooncakeEPDispatcherImpl:
     def _combine_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         buffer = self._get_buffer()
         combined_hidden_states, event, hook = buffer.combine(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             self.active_ranks,
             -1 if self.first_execution else self.timeout_us,
...
@@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_batch).dispatch_a(
+        inner_state = self._get_impl().dispatch_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
+            topk_output=topk_output,
         )
-        self._dispatch_intermediate_state = forward_batch, inner_state
+        self._dispatch_intermediate_state = inner_state

     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_batch, inner_state = self._dispatch_intermediate_state
+        inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+        return self._get_impl().dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
...
@@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_batch).combine_a(
+        inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_batch, inner_state
+        self._combine_intermediate_state = inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_batch, inner_state = self._combine_intermediate_state
+        inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_batch).combine_b(*inner_state)
+        return self._get_impl().combine_b(*inner_state)

-    def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
+    def _get_impl(self) -> _MooncakeEPDispatcherImpl:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
         if resolved_deepep_mode == DeepEPMode.NORMAL:
             raise NotImplementedError
         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
...
@@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def _update_stage(self, old_stage, new_stage):
         assert self._stage == old_stage
         self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        pass
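With this change the Mooncake implementation accepts a single TopKOutput and unpacks topk_ids/topk_weights itself, instead of taking them as separate arguments. The sketch below shows that bundling pattern under stated assumptions: the NamedTuple, function names, and shapes are illustrative stand-ins, with only the field names (topk_weights, topk_ids) taken from the diff.

# Illustrative sketch: bundling router results into one NamedTuple and
# unpacking it only at the dispatcher boundary, mirroring how dispatch_a
# now consumes TopKOutput. Simplified stand-in, not sglang's TopKOutput.
from typing import NamedTuple
import torch

class TopKOutputSketch(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

def route(hidden_states: torch.Tensor, num_experts: int, top_k: int) -> TopKOutputSketch:
    router_logits = torch.randn(hidden_states.shape[0], num_experts)
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
    return TopKOutputSketch(weights, ids, router_logits)

def dispatch_a(hidden_states: torch.Tensor, topk_output: TopKOutputSketch):
    # The dispatcher, not the caller, pulls ids/weights out of the bundle.
    topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
    topk_ids = topk_ids.to(torch.int64)  # communication kernels expect int64 ids
    return hidden_states, topk_ids, topk_weights

if __name__ == "__main__":
    x = torch.randn(4, 8)
    out = dispatch_a(x, route(x, num_experts=16, top_k=2))
    print(out[1].dtype, out[1].shape)  # torch.int64 torch.Size([4, 2])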
python/sglang/srt/layers/moe/token_dispatcher/standard.py
View file @ bfc3b3f7
...
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
 import torch

+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+)
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
...
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
+from sglang.srt.layers.moe.utils import get_moe_runner_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
...
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
 class StandardDispatcher(BaseDispatcher):

+    def __init__(self, moe_runner_config: MoeRunnerConfig):
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.enable_flashinfer_cutlass_moe = (
+            get_moe_runner_backend().is_flashinfer_cutlass()
+        )
+        self.num_experts = moe_runner_config.num_experts
+        self.num_local_experts = moe_runner_config.num_local_experts
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.local_expert_mapping = None
+
     def dispatch(
         self, hidden_states: torch.Tensor, topk_output: TopKOutput
     ) -> DispatchOutput:
+        if (
+            self.moe_ep_size > 1
+            and not self.enable_flashinfer_cutlass_moe
+            and TopKOutputChecker.format_is_standard(topk_output)
+        ):
+            if self.local_expert_mapping is None:
+                self.local_expert_mapping = torch.full(
+                    (self.num_experts,), -1, dtype=torch.int32, device="cuda"
+                )
+                self.local_expert_mapping[
+                    self.moe_ep_rank
+                    * self.num_local_experts : (self.moe_ep_rank + 1)
+                    * self.num_local_experts
+                ] = torch.arange(
+                    0, self.num_local_experts, dtype=torch.int32, device="cuda"
+                )
+
+        if self.local_expert_mapping is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.local_expert_mapping[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+                raise NotImplementedError()
+
         return StandardDispatchOutput(
             hidden_states=hidden_states, topk_output=topk_output
         )
...
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
             # TODO: this branch should be removed in the future
             assert isinstance(combine_input, torch.Tensor)
             return combine_input
+
+    def set_quant_config(self, quant_config: dict):
+        pass
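The new StandardDispatcher.__init__ lazily builds a global-to-local expert-id table under expert parallelism, so top-k ids owned by other EP ranks become -1 before they reach the local kernels. A CPU-only sketch of that remapping follows; the helper name is hypothetical, and the real code keeps the table on the GPU and rewrites TopKOutput via _replace.

# Illustrative sketch of the global->local expert-id remapping that
# StandardDispatcher builds lazily under expert parallelism. Runs on
# CPU here; stand-in helper, not the sglang implementation.
import torch

def build_local_expert_mapping(num_experts: int, num_local_experts: int, ep_rank: int) -> torch.Tensor:
    mapping = torch.full((num_experts,), -1, dtype=torch.int32)
    start = ep_rank * num_local_experts
    end = (ep_rank + 1) * num_local_experts
    mapping[start:end] = torch.arange(0, num_local_experts, dtype=torch.int32)
    return mapping

if __name__ == "__main__":
    # 8 experts split across 2 EP ranks, 4 local experts each; this is rank 1.
    mapping = build_local_expert_mapping(num_experts=8, num_local_experts=4, ep_rank=1)
    topk_ids = torch.tensor([[0, 5], [6, 7]])
    local_ids = mapping[topk_ids]  # experts owned by other ranks become -1
    print(local_ids.tolist())      # [[-1, 1], [2, 3]]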
python/sglang/srt/layers/moe/topk.py
View file @ bfc3b3f7
...
@@ -365,9 +365,10 @@ class TopK(CustomOp):
     def empty_topk_output(self, device: torch.device) -> TopKOutput:
         topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
         topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        # FIXME: router_logits should be of size (0, num_experts)
         router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
-        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+        return StandardTopKOutput(topk_weights, topk_ids, router_logits)


 # ------------------------------- TopK implementation -------------------------------------
...
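empty_topk_output exists so a rank that currently holds no tokens can still drive the regular MoE code path with zero-row tensors instead of branching around it. The short check below, a standalone sketch rather than sglang code, confirms that ordinary tensor ops accept such 0-row inputs without special-casing.

# Illustrative check that zero-row "empty" top-k tensors flow through
# ordinary tensor ops, which is what empty_topk_output relies on.
import torch

top_k = 2
topk_weights = torch.empty((0, top_k), dtype=torch.float32)
topk_ids = torch.full((0, top_k), -1, dtype=torch.int32)

hidden = torch.empty((0, 16))
w = torch.randn(16, 16)
out = hidden @ w                  # still shape (0, 16), no error raised
print(out.shape, topk_ids.shape)  # torch.Size([0, 16]) torch.Size([0, 2])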
python/sglang/srt/layers/quantization/modelopt_quant.py
View file @ bfc3b3f7
...
@@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
...
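Because set_quant_config is now part of the dispatcher interface, a quantization method can push per-layer scales into whichever dispatcher the layer owns once at load time, rather than threading input_global_scale through every dispatch call. The sketch below illustrates that plumbing with stand-in classes; only the key name "input_global_scale" comes from the diff.

# Illustrative sketch of the set_quant_config plumbing: the quant method
# stores a scale in the active dispatcher once, and the dispatcher reads
# it when quantizing inputs. Simplified stand-ins for the sglang classes.
import torch

class LowLatencyDispatcherSketch:
    def __init__(self) -> None:
        self.quant_config: dict = {}

    def set_quant_config(self, quant_config: dict) -> None:
        self.quant_config.update(quant_config)

    def dispatch(self, hidden_states: torch.Tensor) -> torch.Tensor:
        scale = self.quant_config.get("input_global_scale")
        return hidden_states * scale if scale is not None else hidden_states

if __name__ == "__main__":
    dispatcher = LowLatencyDispatcherSketch()
    dispatcher.set_quant_config({"input_global_scale": torch.tensor(0.5)})
    print(dispatcher.dispatch(torch.ones(2, 4)))  # all values scaled to 0.5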
python/sglang/srt/layers/quantization/w4afp8.py
View file @ bfc3b3f7
...
@@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         hidden_states, topk_idx, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if isinstance(hidden_states, tuple):
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @ bfc3b3f7
...
@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
...
@@ -639,6 +640,7 @@ class CudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         kwargs = {}
         if (
...
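The CUDA-graph runner pins the flag to False before capture, which steers the dispatcher's mode resolution toward the low-latency path without needing a ForwardBatch handle. The sketch below shows one way such a resolve step can work; the NORMAL/LOW_LATENCY names mirror DeepEPMode from the diff, but the enum itself is an illustrative stand-in.

# Illustrative sketch of resolving a dispatch mode from the
# is_extend_in_batch flag; stand-in for DeepEPMode.resolve.
from enum import Enum, auto

class ModeSketch(Enum):
    AUTO = auto()
    NORMAL = auto()
    LOW_LATENCY = auto()

    def resolve(self, is_extend_in_batch: bool) -> "ModeSketch":
        if self is not ModeSketch.AUTO:
            return self
        # Prefill/extend batches favor throughput; decode favors latency.
        return ModeSketch.NORMAL if is_extend_in_batch else ModeSketch.LOW_LATENCY

if __name__ == "__main__":
    print(ModeSketch.AUTO.resolve(is_extend_in_batch=False))  # ModeSketch.LOW_LATENCY
    print(ModeSketch.AUTO.resolve(is_extend_in_batch=True))   # ModeSketch.NORMAL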
python/sglang/srt/model_executor/forward_batch_info.py
View file @ bfc3b3f7
...
@@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
...
@@ -688,6 +689,7 @@ class ForwardBatch:
         self.global_dp_buffer_len = buffer_len
         set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
+        set_is_extend_in_batch(self.is_extend_in_batch)

         bs = self.batch_size
...
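On the non-graph path the flag comes from the batch itself: ForwardBatch now publishes is_extend_in_batch through the dp_attention setter while it prepares DP metadata, and the dispatchers read it back later via get_is_extend_in_batch. A tiny self-contained sketch of that publish-then-read pattern, with stand-in names, is shown below.

# Illustrative sketch: batch preparation publishes a process-wide flag so
# MoE code can query it later without holding the batch object. The
# module global and dataclass here are stand-ins, not the sglang module.
from dataclasses import dataclass

_IS_EXTEND_IN_BATCH = False

def set_is_extend_in_batch(value: bool) -> None:
    global _IS_EXTEND_IN_BATCH
    _IS_EXTEND_IN_BATCH = value

def get_is_extend_in_batch() -> bool:
    return _IS_EXTEND_IN_BATCH

@dataclass
class ForwardBatchSketch:
    is_extend_in_batch: bool

    def prepare_mlp_sync(self) -> None:
        # Mirrors ForwardBatch publishing its flag during batch preparation.
        set_is_extend_in_batch(self.is_extend_in_batch)

if __name__ == "__main__":
    ForwardBatchSketch(is_extend_in_batch=True).prepare_mlp_sync()
    print(get_is_extend_in_batch())  # True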
python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
View file @ bfc3b3f7
...
@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
...
@@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        # FIXME: the implementation is hacky. `is_extend_in_batch` is for determining the deepep mode.
+        # It is True in this context but we need to set it to use low latency deepep mode.
+        set_is_extend_in_batch(False)

         kwargs = {}
         with set_forward_context(forward_batch, self.attention_layers):
...
python/sglang/srt/models/bailing_moe.py
View file @ bfc3b3f7
...
@@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             if self.num_shared_experts > 0:
                 shared_output = self.shared_experts(hidden_states)

-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
...
@@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
-
-        if self.ep_size > 1:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)

         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )

-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
+        final_hidden_states *= self.routed_scaling_factor

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states += shared_output

         return final_hidden_states
...
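On the model side the refactor collapses the per-model EP branching: the block routes (or asks TopK for an empty output when it has no tokens) and passes hidden_states plus the bundled topk_output to self.experts, which owns dispatch and combine internally. The sketch below mirrors that call shape with stand-in classes; it performs no real expert dispatch.

# Illustrative sketch of the simplified MoE-block forward: route or build
# an empty routing result, then hand both tensors to an experts module.
# Stand-in classes only, not the sglang implementation.
from typing import NamedTuple
import torch

class TopKOutputSketch(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor

class ExpertsSketch:
    def __init__(self, hidden_size: int) -> None:
        self.proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)

    def __call__(self, hidden_states: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
        # A real implementation would dispatch tokens to experts using
        # topk_output; here we only demonstrate the call shape.
        return self.proj(hidden_states)

def moe_block_forward(hidden_states: torch.Tensor, experts: ExpertsSketch, top_k: int = 2) -> torch.Tensor:
    if hidden_states.shape[0] > 0:
        logits = torch.randn(hidden_states.shape[0], 8)
        w, ids = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
        topk_output = TopKOutputSketch(w, ids)
    else:
        topk_output = TopKOutputSketch(
            torch.empty((0, top_k)), torch.full((0, top_k), -1, dtype=torch.int32)
        )
    return experts(hidden_states=hidden_states, topk_output=topk_output)

if __name__ == "__main__":
    experts = ExpertsSketch(hidden_size=16)
    print(moe_block_forward(torch.randn(3, 16), experts).shape)  # torch.Size([3, 16])
    print(moe_block_forward(torch.randn(0, 16), experts).shape)  # torch.Size([0, 16])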
python/sglang/srt/models/deepseek_v2.py
View file @ bfc3b3f7
...
@@ -74,7 +74,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
...
@@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
...
@@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module):
             else None
         )

-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.n_routed_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=get_deepep_mode(),
-            async_finish=True,
-            return_recv_hook=True,
-        )
-
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
...
@@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module):
             router_logits = self.gate(hidden_states)
             if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
...
@@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         if self._fuse_shared_experts_inside_sbo:
             shared_output = None
...
@@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
             **(
                 dict(
                     forward_shared_experts=_forward_shared_experts_and_put_results,
...
@@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
...
@@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module):
                     ),
                 )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-                input_global_scale=None,
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
...
@@ -992,33 +963,30 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
-            )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )

     def op_output(self, state):
         final_hidden_states = state.pop("hidden_states_after_combine")
...
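For two-batch overlap, the op_* hooks now reach the dispatcher through self.experts.dispatcher and call self.experts.run_moe_core between dispatch and combine, with no ForwardBatch handle in the call chain. A stand-in sketch of that five-step sequence is below; the class names and the trivial "expert computation" are illustrative only.

# Illustrative sketch of the op sequence driven after the refactor:
# dispatch_a/b -> run_moe_core -> combine_a/b, with intermediate results
# held in the dispatcher and a state dict. Stand-in classes, not sglang;
# the real code interleaves the ops of two micro-batches.
import torch

class DispatcherSketch:
    def dispatch_a(self, hidden_states):
        self._tmp = hidden_states          # pretend communication was launched

    def dispatch_b(self):
        out, self._tmp = self._tmp, None
        return out                         # the "dispatch output"

    def combine_a(self, hidden_states):
        self._tmp = hidden_states

    def combine_b(self):
        out, self._tmp = self._tmp, None
        return out

class ExpertsSketch:
    dispatcher = DispatcherSketch()

    def run_moe_core(self, dispatch_output):
        return dispatch_output * 2         # stand-in expert computation

def run_ops(experts, hidden_states):
    state = {}
    experts.dispatcher.dispatch_a(hidden_states)
    state["dispatch_output"] = experts.dispatcher.dispatch_b()
    state["experts_output"] = experts.run_moe_core(state.pop("dispatch_output"))
    experts.dispatcher.combine_a(state.pop("experts_output"))
    return experts.dispatcher.combine_b()

if __name__ == "__main__":
    print(run_ops(ExpertsSketch(), torch.ones(2, 4)))  # tensor of 2.0s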
python/sglang/srt/models/glm4_moe.py
View file @ bfc3b3f7
...
@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
...
@@ -49,7 +48,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
...
@@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2MoE,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
...
@@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             else None
         )

-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.n_routed_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=get_deepep_mode(),
-            async_finish=True,
-            return_recv_hook=True,
-        )
-
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
...
python/sglang/srt/models/qwen2_moe.py
View file @ bfc3b3f7
...
@@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
...
@@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
         if shared_output is not None:
...
python/sglang/srt/models/qwen3_moe.py
View file @ bfc3b3f7
...
@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
...
@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )

         return final_hidden_states
...
@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
...
@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                     ),
                 )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.pop("hidden_states_mlp_input"),
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
...
@@ -250,33 +236,30 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
            dispatch_output=state.dispatch_output,
        )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
-            )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )

     def op_output(self, state):
         state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
...