sglang · Commit 64994980 (Unverified)

[10/N] MoE Refactor: reorganize deepgemm runner in DeepEPMoE (#12054)

Authored Oct 24, 2025 by Cheng Wan; committed via GitHub on Oct 24, 2025.
Parent commit: 729b2429

Showing 11 changed files with 397 additions and 349 deletions (+397 -349).
Changed files:

python/sglang/srt/layers/moe/ep_moe/layer.py                +69   -280
python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +1     -1
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py       +287    -22
python/sglang/srt/layers/moe/token_dispatcher/__init__.py    +4     -4
python/sglang/srt/layers/moe/token_dispatcher/base.py        +5     -5
python/sglang/srt/layers/moe/token_dispatcher/deepep.py     +18    -14
python/sglang/srt/layers/quantization/fp8.py                 +3     -4
python/sglang/srt/layers/quantization/w4afp8.py              +4     -4
python/sglang/srt/models/deepseek_v2.py                      +2     -4
python/sglang/srt/models/qwen3_moe.py                        +2     -4
python/sglang/srt/single_batch_overlap.py                    +2     -7
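For orientation before the per-file diffs: the net effect of this refactor on the calling convention is that the experts layer's run_moe_core now returns a combine-input object, and the dispatcher's combine accepts that single object as combine_input instead of separate hidden_states/topk_ids/topk_weights tensors. The sketch below is illustrative only; it uses toy stand-in classes (not sglang's) that assume this reading of the diffs.

# Illustrative stand-ins only; the real classes live in sglang.srt.layers.moe.*
from typing import NamedTuple, Optional

import torch


class ToyCombineInput(NamedTuple):
    # Mirrors the fields this commit adds to DeepEPNormalCombineInput / DeepEPLLCombineInput.
    hidden_states: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor
    overlap_args: Optional[object] = None


class ToyDispatcher:
    def combine(self, combine_input: ToyCombineInput) -> torch.Tensor:
        # New-style keyword call: combine(combine_input=...), as in FusedMoE.forward below.
        hidden_states, topk_ids, topk_weights, _ = combine_input
        # A trivial "combine": scale each token's hidden states by its summed routing weight.
        return hidden_states * topk_weights.sum(dim=-1, keepdim=True)


def toy_run_moe_core(hidden_states: torch.Tensor,
                     topk_ids: torch.Tensor,
                     topk_weights: torch.Tensor) -> ToyCombineInput:
    # Stands in for DeepEPMoE.run_moe_core: the expert computation now hands back
    # a combine-input wrapper rather than bare hidden states.
    return ToyCombineInput(hidden_states, topk_ids, topk_weights)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    ids = torch.zeros(4, 2, dtype=torch.int32)
    w = torch.full((4, 2), 0.5)
    out = ToyDispatcher().combine(combine_input=toy_run_moe_core(x, ids, w))
    print(out.shape)  # torch.Size([4, 8])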
python/sglang/srt/layers/moe/ep_moe/layer.py

 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch

@@ -13,29 +13,23 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    ep_gather,
-    ep_scatter,
-    silu_and_mul_masked_post_quant_fwd,
-    tma_align_input_scale,
-)
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPLLCombineInput,
+    DeepEPNormalCombineInput,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.fp8_kernel import (
-    is_fp8_fnuz,
-    sglang_per_token_group_quant_fp8,
-)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
-from sglang.srt.utils.offloader import get_offloader
+from sglang.srt.utils import get_bool_env_var, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        DeepEPLLOutput,
-        DeepEPNormalOutput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalDispatchOutput,
         DispatchOutput,
     )

@@ -45,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if not (_is_npu or _is_hip):
-    from sgl_kernel import silu_and_mul
+    pass

 if _use_aiter:
     from aiter import ActivationType, QuantType

@@ -90,6 +84,18 @@ class DeepEPMoE(FusedMoE):
             routed_scaling_factor=routed_scaling_factor,
         )

+        if _use_aiter or _is_npu:
+            self.deprecate_flag = False
+        elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
+            quant_config, Fp8Config
+        ):
+            self.deprecate_flag = True
+        else:
+            self.deprecate_flag = False
+
+        if self.deprecate_flag:
+            return
+
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.use_fp8_w8a8 = True

@@ -152,6 +158,14 @@ class DeepEPMoE(FusedMoE):
         disable_sbo=False,
     ):
+        if self.deprecate_flag:
+            assert forward_shared_experts is None
+            assert alt_stream is None
+            return super().forward(
+                hidden_states,
+                topk_output,
+            )
+
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,

@@ -178,37 +192,51 @@ class DeepEPMoE(FusedMoE):
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
     ):
+        if self.deprecate_flag:
+            assert down_gemm_overlap_args is None
+            return super().run_moe_core(
+                dispatch_output,
+            )
+
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

         if _use_aiter:
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(dispatch_output)
-        if _is_npu:
+            output = self.forward_aiter(dispatch_output)
+        elif _is_npu:
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
-            return self.forward_npu(dispatch_output)
-        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            output = self.forward_npu(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             if self.use_w4afp8:
-                return self.forward_cutlass_w4afp8(dispatch_output)
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            return self.forward_deepgemm_contiguous(dispatch_output)
+                output = self.forward_cutlass_w4afp8(dispatch_output)
+            else:
+                assert False, "forward_deepgemm_contiguous is deprecated"
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if (
                 get_moe_runner_backend().is_flashinfer_cutedsl()
                 and self.quant_config.get_name() == "modelopt_fp4"
             ):
-                return self.forward_flashinfer_cutedsl(
+                output = self.forward_flashinfer_cutedsl(
                     dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
                 )
             elif self.use_w4afp8:
-                return self.forward_cutlass_w4afp8_masked(dispatch_output)
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            assert down_gemm_overlap_args is None
-            return self.forward_deepgemm_masked(dispatch_output)
-        else:
-            raise ValueError(
-                f"Dispatch output format {dispatch_output.format} is not supported"
-            )
+                output = self.forward_cutlass_w4afp8_masked(dispatch_output)
+            else:
+                assert False, "forward_deepgemm_masked is deprecated"
+
+        combine_input_wrapper = (
+            DeepEPNormalCombineInput
+            if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
+            else DeepEPLLCombineInput
+        )
+        return combine_input_wrapper(
+            hidden_states=output,
+            topk_ids=dispatch_output.topk_ids,
+            topk_weights=dispatch_output.topk_weights,
+            overlap_args=down_gemm_overlap_args,
+        )

     def combine(
         self,

@@ -226,7 +254,7 @@ class DeepEPMoE(FusedMoE):
     def forward_aiter(
         self,
-        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+        dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
     ):
         hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,

@@ -258,158 +286,9 @@ class DeepEPMoE(FusedMoE):
             expert_mask=self.expert_mask,
         )

-    def forward_deepgemm_contiguous(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        (
-            hidden_states,
-            hidden_states_scale,
-            topk_ids,
-            topk_weights,
-            num_recv_tokens_per_expert,
-        ) = dispatch_output
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-        if num_recv_tokens_per_expert is None:
-            return hidden_states.bfloat16()
-        all_tokens = sum(num_recv_tokens_per_expert)
-        if all_tokens <= 0:
-            return hidden_states.bfloat16()
-        M, K = hidden_states.size()
-        N = self.w13_weight.size(1)
-        scale_block_size = 128
-
-        w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        w2_weight_fp8 = (
-            self.w2_weight,
-            (self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale),
-        )
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_device = hidden_states.device
-        hidden_states_dtype = hidden_states.dtype
-
-        input_tensor = [
-            torch.empty(
-                (all_tokens, K),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            ),
-            (
-                # TODO check whether need `zeros`
-                torch.zeros(
-                    (ceil_div(K // 128, 4), all_tokens),
-                    device=hidden_states.device,
-                    dtype=torch.int,
-                ).transpose(0, 1)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else torch.empty(
-                    (all_tokens, K // 128),
-                    device=hidden_states.device,
-                    dtype=torch.float32,
-                )
-            ),
-        ]
-        m_indices = torch.empty(
-            all_tokens, device=hidden_states.device, dtype=torch.int32
-        )
-        output_index = torch.empty_like(topk_ids)
-
-        if get_offloader().forbid_copy_engine_usage:
-            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
-                num_recv_tokens_per_expert
-            )
-        else:
-            num_recv_tokens_per_expert_gpu = torch.tensor(
-                num_recv_tokens_per_expert,
-                dtype=torch.int32,
-                pin_memory=True,
-                device="cpu",
-            ).cuda(non_blocking=True)
-        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
-
-        ep_scatter(
-            hidden_states,
-            hidden_states_scale,
-            topk_ids,
-            num_recv_tokens_per_expert_gpu,
-            expert_start_loc,
-            input_tensor[0],
-            input_tensor[1],
-            m_indices,
-            output_index,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        dispose_tensor(hidden_states)
-
-        gateup_output = torch.empty(
-            (all_tokens, N),
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            input_tensor[1] = tma_align_input_scale(input_tensor[1])
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor, w13_weight_fp8, gateup_output, m_indices
-        )
-        del input_tensor
-        down_input = torch.empty(
-            (
-                all_tokens,
-                N // 2,
-            ),
-            device=gateup_output.device,
-            dtype=torch.bfloat16,
-        )
-        silu_and_mul(gateup_output.view(-1, N), down_input)
-        del gateup_output
-        down_output = torch.empty(
-            (all_tokens, K),
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input,
-            scale_block_size,
-            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del down_input
-        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            down_input_scale = tma_align_input_scale(down_input_scale)
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            (down_input_fp8, down_input_scale),
-            w2_weight_fp8,
-            down_output,
-            m_indices,
-        )
-        del down_input_fp8, down_input_scale
-        gather_out = torch.empty(
-            hidden_states_shape,
-            device=hidden_states_device,
-            dtype=torch.bfloat16,
-        )
-        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
-
-        return gather_out
-
     def forward_flashinfer_cutedsl(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: DeepEPLLDispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
         hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output

@@ -427,7 +306,7 @@ class DeepEPMoE(FusedMoE):
     def forward_cutlass_w4afp8(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPNormalDispatchOutput,
     ):
         assert self.moe_runner_config.activation == "silu"
         assert isinstance(self.quant_method, W4AFp8MoEMethod)

@@ -436,90 +315,9 @@ class DeepEPMoE(FusedMoE):
             dispatch_output=dispatch_output,
         )

-    def forward_deepgemm_masked(
-        self,
-        dispatch_output: DeepEPLLOutput,
-    ):
-        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-        assert hidden_states_scale.dtype == torch.float32 or (
-            deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-            and hidden_states_scale.dtype == torch.int32
-        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}, DEEPGEMM_SCALE_UE8M0: {deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0}"
-
-        # GroupGemm-0
-        num_groups, m, k = hidden_states.size()
-        n = self.w13_weight.size(1)
-        expected_m = min(expected_m, m)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            (hidden_states, hidden_states_scale),
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        dispose_tensor(hidden_states)
-
-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=gateup_output.device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=gateup_output.device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
-        )
-
-        return down_output
-
     def forward_cutlass_w4afp8_masked(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPLLDispatchOutput,
     ):
         assert self.moe_runner_config.activation == "silu"
         assert isinstance(self.quant_method, W4AFp8MoEMethod)

@@ -533,7 +331,7 @@ class DeepEPMoE(FusedMoE):
     def forward_npu(
         self,
-        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
+        dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
     ):
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

@@ -546,9 +344,9 @@ class DeepEPMoE(FusedMoE):
         output_dtype = torch.bfloat16
         group_list_type = 1

-        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+        def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
             if TYPE_CHECKING:
-                assert isinstance(dispatch_output, DeepEPNormalOutput)
+                assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
             hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
                 dispatch_output
             )

@@ -618,9 +416,9 @@ class DeepEPMoE(FusedMoE):
             return hidden_states

-        def _forward_ll(dispatch_output: DeepEPLLOutput):
+        def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
             if TYPE_CHECKING:
-                assert isinstance(dispatch_output, DeepEPLLOutput)
+                assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
             (
                 hidden_states,
                 hidden_states_scale,

@@ -731,12 +529,3 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     return FusedMoE
-
-
-def copy_list_to_gpu_no_ce(arr: List[int]):
-    from sgl_kernel.elementwise import copy_to_gpu_no_ce
-
-    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
-    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
-    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
-    return tensor_gpu
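The gating added to DeepEPMoE above can be read as a single predicate: delegate to the generic FusedMoE/MoE-runner path exactly when JIT DeepGEMM is enabled and the layer uses a plain Fp8Config, and keep the legacy in-class path for aiter/NPU and every other quant config. A hedged restatement of that decision as a standalone helper (the function and parameter names are illustrative, not sglang's):

# Hypothetical helper; the logic mirrors the deprecate_flag assignment in DeepEPMoE.__init__
# above, with plain booleans standing in for _use_aiter/_is_npu,
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(quant_config, Fp8Config).
def uses_moe_runner_deepgemm(use_aiter: bool,
                             is_npu: bool,
                             jit_deepgemm_enabled: bool,
                             is_plain_fp8_config: bool) -> bool:
    if use_aiter or is_npu:
        return False
    return jit_deepgemm_enabled and is_plain_fp8_config


assert uses_moe_runner_deepgemm(False, False, True, True) is True
assert uses_moe_runner_deepgemm(True, False, True, True) is False    # aiter keeps the legacy path
assert uses_moe_runner_deepgemm(False, False, True, False) is False  # e.g. W4AFp8 keeps the legacy path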
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
             dispatch_output=dispatch_output,
             **kwargs,
         )
-        final_hidden_states = self.dispatcher.combine(combine_input)
+        final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

         # TODO: should we add some conditions here?
         final_hidden_states = final_hidden_states[
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch

+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe.moe_runner.base import (
     MoeQuantInfo,
     MoeRunnerConfig,

@@ -15,14 +16,28 @@ from sglang.srt.layers.moe.moe_runner.base import (
     register_pre_permute,
 )
 from sglang.srt.layers.moe.utils import MoeRunnerBackend
-from sglang.srt.utils import dispose_tensor
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.deepep import (
+        DeepEPLLCombineInput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalCombineInput,
+        DeepEPNormalDispatchOutput,
+    )
     from sglang.srt.layers.moe.token_dispatcher.standard import (
         StandardCombineInput,
         StandardDispatchOutput,
     )

+_is_hip = is_hip()
+_is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if not (_is_npu or _is_hip):
+    from sgl_kernel import silu_and_mul
+

 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.

@@ -40,13 +55,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
     return new_x.transpose(1, 2).contiguous().transpose(1, 2)


+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
+
+
 @dataclass
 class DeepGemmRunnerInput(RunnerInput):

     hidden_states: torch.Tensor
     hidden_states_scale: torch.Tensor
-    masked_m: torch.Tensor
-    expected_m: int
     use_masked_gemm: bool
+    masked_m: Optional[torch.Tensor] = None
+    expected_m: Optional[int] = None
+    m_indices: Optional[torch.Tensor] = None

     @property
     def runner_backend(self) -> MoeRunnerBackend:

@@ -84,20 +109,100 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         running_state: dict,
     ) -> DeepGemmRunnerOutput:

-        if runner_input.use_masked_gemm:
-            hidden_states = self._run_masked_gemm(
-                runner_input,
-                quant_info,
-                running_state,
-            )
+        if not runner_input.use_masked_gemm:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input, quant_info, running_state
+            )
         else:
-            hidden_states = self._run_contiguous_gemm(
-                runner_input,
-                quant_info,
-                running_state,
-            )
+            hidden_states = self._run_masked_gemm(
+                runner_input, quant_info, running_state
+            )
         return DeepGemmRunnerOutput(hidden_states=hidden_states)

+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        all_tokens = running_state["all_tokens"]
+        hidden_states_device = running_state["hidden_states_device"]
+        hidden_states_dtype = running_state["hidden_states_dtype"]
+        hidden_states_shape = running_state["hidden_states_shape"]
+        m_indices = runner_input.m_indices
+
+        N = quant_info.w13_weight.size(1)
+        K = hidden_states_shape[1]
+        scale_block_size = 128
+
+        w13_weight_fp8 = (
+            quant_info.w13_weight,
+            quant_info.w13_scale,
+        )
+        w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_device,
+            dtype=torch.bfloat16,
+        )
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = tma_align_input_scale(hidden_states_scale)
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+            (hidden_states, hidden_states_scale),
+            w13_weight_fp8,
+            gateup_output,
+            m_indices,
+        )
+        dispose_tensor(hidden_states)
+        dispose_tensor(hidden_states_scale)
+
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
+
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del down_input
+
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_device,
+            dtype=torch.bfloat16,
+        )
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+            (down_input_fp8, down_input_scale),
+            w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+        return down_output
+
     def _run_masked_gemm(
         self,
         runner_input: DeepGemmRunnerInput,

@@ -149,6 +254,7 @@ class DeepGemmRunnerCore(MoeRunnerCore):
             expected_m,
         )
         dispose_tensor(hidden_states)
+        dispose_tensor(hidden_states_scale)

         # Act
         down_input = torch.empty(

@@ -198,18 +304,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
             masked_m,
             expected_m,
         )
+        del down_input

         return down_output

-    def _run_contiguous_gemm(
-        self,
-        runner_input: DeepGemmRunnerInput,
-        quant_info: DeepGemmMoeQuantInfo,
-        running_state: dict,
-    ) -> torch.Tensor:
-        pass
-
     @property
     def runner_backend(self) -> MoeRunnerBackend:
         return MoeRunnerBackend.DEEP_GEMM

@@ -222,6 +319,7 @@ def pre_permute_standard_to_deep_gemm(
     runner_config: MoeRunnerConfig,
     running_state: dict,
 ) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess

     hidden_states, topk_output = dispatch_output

@@ -257,9 +355,9 @@ def pre_permute_standard_to_deep_gemm(
     return DeepGemmRunnerInput(
         hidden_states=hidden_states,
         hidden_states_scale=hidden_states_scale,
+        use_masked_gemm=True,
         masked_m=masked_m,
         expected_m=expected_m,
-        use_masked_gemm=True,
     )

@@ -302,3 +400,170 @@ def post_permute_deep_gemm_to_standard(
     return StandardCombineInput(
         hidden_states=output,
     )
+
+
+@register_pre_permute("deepep_ll", "deep_gemm")
+def pre_permute_deepep_ll_to_deep_gemm(
+    dispatch_output: DeepEPLLDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+
+    hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
+        dispatch_output
+    )
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states.shape
+    running_state["hidden_states_dtype"] = hidden_states.dtype
+    running_state["hidden_states_device"] = hidden_states.device
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        use_masked_gemm=True,
+        masked_m=masked_m,
+        expected_m=expected_m,
+    )
+
+
+@register_post_permute("deep_gemm", "deepep_ll")
+def post_permute_deep_gemm_to_deepep_ll(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepEPLLCombineInput:
+    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
+
+    return DeepEPLLCombineInput(
+        hidden_states=runner_output.hidden_states,
+        topk_ids=running_state["topk_ids"],
+        topk_weights=running_state["topk_weights"],
+    )
+
+
+@register_pre_permute("deepep_normal", "deep_gemm")
+def pre_permute_deepep_normal_to_deep_gemm(
+    dispatch_output: DeepEPNormalDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter
+
+    (
+        hidden_states,
+        hidden_states_scale,
+        topk_ids,
+        topk_weights,
+        num_recv_tokens_per_expert,
+    ) = dispatch_output
+
+    assert runner_config.activation == "silu"
+
+    all_tokens = sum(num_recv_tokens_per_expert)
+    running_state["all_tokens"] = all_tokens
+
+    K = hidden_states.shape[1]
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_device = hidden_states.device
+    hidden_states_dtype = hidden_states.dtype
+
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+
+    input_tensor = torch.empty(
+        (all_tokens, K),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+        # TODO check whether need `zeros`
+        input_tensor_scale = torch.zeros(
+            (ceil_div(K // 128, 4), all_tokens),
+            device=hidden_states.device,
+            dtype=torch.int,
+        ).transpose(0, 1)
+    else:
+        input_tensor_scale = torch.empty(
+            (all_tokens, K // 128),
+            device=hidden_states.device,
+            dtype=torch.float32,
+        )
+    m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
+    output_index = torch.empty_like(topk_ids)
+
+    if get_offloader().forbid_copy_engine_usage:
+        num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+            num_recv_tokens_per_expert
+        )
+    else:
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+    ep_scatter(
+        hidden_states,
+        hidden_states_scale,
+        topk_ids,
+        num_recv_tokens_per_expert_gpu,
+        expert_start_loc,
+        input_tensor,
+        input_tensor_scale,
+        m_indices,
+        output_index,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+    )
+    dispose_tensor(hidden_states)
+    dispose_tensor(hidden_states_scale)
+
+    running_state["output_index"] = output_index
+
+    return DeepGemmRunnerInput(
+        hidden_states=input_tensor,
+        hidden_states_scale=input_tensor_scale,
+        use_masked_gemm=False,
+        m_indices=m_indices,
+    )
+
+
+@register_post_permute("deep_gemm", "deepep_normal")
+def post_permute_deep_gemm_to_deepep_normal(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepEPNormalCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
+    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput
+
+    hidden_states = runner_output.hidden_states
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+    output_index = running_state["output_index"]
+
+    gather_out = torch.empty(
+        running_state["hidden_states_shape"],
+        device=running_state["hidden_states_device"],
+        dtype=torch.bfloat16,
+    )
+    ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)
+
+    return DeepEPNormalCombineInput(
+        hidden_states=gather_out,
+        topk_ids=running_state["topk_ids"],
+        topk_weights=running_state["topk_weights"],
+    )
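The new pre_permute/post_permute functions above are registered per (dispatch format, runner backend) pair, so the MoE runner can look up how to turn a DeepEP dispatch output into a DeepGemmRunnerInput and how to wrap the runner output back into a combine input. A minimal sketch of that registration pattern follows, using a toy registry rather than sglang's actual register_pre_permute/register_post_permute implementation in moe_runner.base; the decorated bodies are placeholders.

# Toy registry illustrating hooks keyed by (dispatch_format, runner_backend);
# the real decorators live in sglang.srt.layers.moe.moe_runner.base.
from typing import Callable, Dict, Tuple

_PRE_PERMUTE: Dict[Tuple[str, str], Callable] = {}
_POST_PERMUTE: Dict[Tuple[str, str], Callable] = {}


def register_pre_permute(dispatch_format: str, runner_backend: str):
    def deco(fn: Callable) -> Callable:
        _PRE_PERMUTE[(dispatch_format, runner_backend)] = fn
        return fn
    return deco


def register_post_permute(runner_backend: str, combine_format: str):
    def deco(fn: Callable) -> Callable:
        _POST_PERMUTE[(runner_backend, combine_format)] = fn
        return fn
    return deco


@register_pre_permute("deepep_ll", "deep_gemm")
def pre_permute_deepep_ll(dispatch_output, quant_info, runner_config, running_state):
    # Would build a DeepGemmRunnerInput(use_masked_gemm=True, masked_m=..., expected_m=...).
    return {"use_masked_gemm": True}


@register_post_permute("deep_gemm", "deepep_ll")
def post_permute_deepep_ll(runner_output, quant_info, runner_config, running_state):
    # Would wrap runner_output.hidden_states into a DeepEPLLCombineInput.
    return {"hidden_states": runner_output}


print(_PRE_PERMUTE[("deepep_ll", "deep_gemm")](None, None, None, {}))  # {'use_masked_gemm': True}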
python/sglang/srt/layers/moe/token_dispatcher/__init__.py

@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
     DeepEPConfig,
     DeepEPDispatcher,
     DeepEPLLCombineInput,
-    DeepEPLLOutput,
+    DeepEPLLDispatchOutput,
     DeepEPNormalCombineInput,
-    DeepEPNormalOutput,
+    DeepEPNormalDispatchOutput,
 )
 from sglang.srt.layers.moe.token_dispatcher.mooncake import (
     MooncakeCombineInput,

@@ -44,8 +44,8 @@ __all__ = [
     "StandardCombineInput",
     "DeepEPConfig",
     "DeepEPDispatcher",
-    "DeepEPNormalOutput",
-    "DeepEPLLOutput",
+    "DeepEPNormalDispatchOutput",
+    "DeepEPLLDispatchOutput",
     "DeepEPLLCombineInput",
     "DeepEPNormalCombineInput",
 ]
python/sglang/srt/layers/moe/token_dispatcher/base.py

@@ -9,9 +9,9 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLCombineInput,
-        DeepEPLLOutput,
+        DeepEPLLDispatchOutput,
         DeepEPNormalCombineInput,
-        DeepEPNormalOutput,
+        DeepEPNormalDispatchOutput,
         StandardCombineInput,
         StandardDispatchOutput,
     )

@@ -37,19 +37,19 @@ class DispatchOutputChecker:
     @staticmethod
     def format_is_deepep_normal(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPNormalOutput]:
+    ) -> TypeGuard[DeepEPNormalDispatchOutput]:
         return dispatch_output.format.is_deepep_normal()

     @staticmethod
     def format_is_deepep_ll(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPLLOutput]:
+    ) -> TypeGuard[DeepEPLLDispatchOutput]:
         return dispatch_output.format.is_deepep_ll()

     @staticmethod
     def format_is_deepep(
         dispatch_output: DispatchOutput,
-    ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
+    ) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
         return dispatch_output.format.is_deepep()
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -58,7 +58,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 logger = logging.getLogger(__name__)


-class DeepEPNormalOutput(NamedTuple):
+class DeepEPNormalDispatchOutput(NamedTuple):
     """DeepEP normal dispatch output."""

     hidden_states: torch.Tensor

@@ -72,7 +72,7 @@ class DeepEPNormalOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_NORMAL


-class DeepEPLLOutput(NamedTuple):
+class DeepEPLLDispatchOutput(NamedTuple):
     """DeepEP low latency dispatch output."""

     hidden_states: torch.Tensor

@@ -87,14 +87,17 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL


-assert isinstance(DeepEPNormalOutput, DispatchOutput)
-assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(DeepEPNormalDispatchOutput, DispatchOutput)
+assert isinstance(DeepEPLLDispatchOutput, DispatchOutput)


 class DeepEPNormalCombineInput(NamedTuple):
     """DeepEP normal combine input."""

-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None

     @property
     def format(self) -> CombineInputFormat:

@@ -104,7 +107,10 @@ class DeepEPNormalCombineInput(NamedTuple):
 class DeepEPLLCombineInput(NamedTuple):
     """DeepEP low latency combine input."""

-    pass
+    hidden_states: torch.Tensor
+    topk_ids: torch.Tensor
+    topk_weights: torch.Tensor
+    overlap_args: Optional[CombineOverlapArgs] = None

     @property
     def format(self) -> CombineInputFormat:

@@ -383,7 +389,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None

-        return DeepEPNormalOutput(
+        return DeepEPNormalDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,

@@ -562,7 +568,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         else:
             hidden_states_scale = None

-        deepep_output = DeepEPLLOutput(
+        deepep_output = DeepEPLLDispatchOutput(
             hidden_states,
             hidden_states_scale,
             topk_ids,

@@ -756,18 +762,16 @@ class DeepEPDispatcher(BaseDispatcher):
             del self._dispatch_intermediate_state
         return self._get_impl().dispatch_b(*inner_state)

-    def combine(self, *args, **kwargs) -> Tuple:
-        self.combine_a(*args, **kwargs)
+    def combine(self, combine_input: CombineInput) -> Tuple:
+        self.combine_a(combine_input)
         ret = self.combine_b()
         return ret

     def combine_a(
         self,
-        hidden_states: torch.Tensor,
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_args: Optional["CombineOverlapArgs"] = None,
+        combine_input: CombineInput,
     ):
+        hidden_states, topk_ids, topk_weights, overlap_args = combine_input
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
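The deepep.py diff above replaces the empty placeholder combine inputs with real NamedTuples and switches combine_a to accept one object, unpacking it positionally. A short standalone sketch of that contract follows, assuming the four-field layout shown above; the toy class stands in for DeepEPLLCombineInput / DeepEPNormalCombineInput.

# Toy version of the new combine-input contract; real classes live in
# sglang.srt.layers.moe.token_dispatcher.deepep.
from typing import NamedTuple, Optional


class ToyLLCombineInput(NamedTuple):
    hidden_states: str
    topk_ids: str
    topk_weights: str
    overlap_args: Optional[dict] = None  # optional, so three-field construction still works


def combine_a(combine_input: ToyLLCombineInput) -> None:
    # Mirrors the positional unpacking added to DeepEPDispatcher.combine_a;
    # field order therefore matters.
    hidden_states, topk_ids, topk_weights, overlap_args = combine_input
    print(hidden_states, topk_ids, topk_weights, overlap_args)


combine_a(ToyLLCombineInput("h", "ids", "w"))                         # overlap_args falls back to None
combine_a(ToyLLCombineInput("h", "ids", "w", overlap_args={"k": 1}))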
python/sglang/srt/layers/quantization/fp8.py

@@ -984,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output

         moe_runner_config = self.moe_runner_config

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )

@@ -1017,7 +1016,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ret = self.maybe_apply_hip_fused_experts(
             layer,
             x,
-            topk_output,
+            dispatch_output.topk_output,
             moe_runner_config.activation,
             moe_runner_config.no_combine,
         )

@@ -1027,7 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if self._should_use_cutlass_fused_experts():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
python/sglang/srt/layers/quantization/w4afp8.py

@@ -23,8 +23,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
-        DeepEPLLOutput,
-        DeepEPNormalOutput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalDispatchOutput,
         StandardDispatchOutput,
     )

@@ -332,7 +332,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply_deepep_ll(
         self,
         layer: DeepEPMoE,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: DeepEPLLDispatchOutput,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll

@@ -367,7 +367,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply_deepep_normal(
         self,
         layer: DeepEPMoE,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPNormalDispatchOutput,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.cutlass_w4a8_moe import (
             cutlass_w4a8_moe_deepep_normal,
python/sglang/srt/models/deepseek_v2.py

@@ -1005,16 +1005,14 @@ class DeepseekV2MoE(nn.Module):
         )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.run_moe_core(
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-                hidden_states=state.pop("hidden_states_experts_output"),
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
python/sglang/srt/models/qwen3_moe.py

@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.run_moe_core(
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-                hidden_states=state.pop("hidden_states_experts_output"),
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
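In both model files above, the intermediate handed from op_experts to op_combine_a changes from raw expert hidden states (plus separately re-read topk tensors) to the single combine_input produced by run_moe_core. A minimal sketch of that two-op handoff through a state mapping, using a toy dict and lambdas rather than the sglang operation framework:

# Toy two-op pipeline showing the handoff; the real code lives in DeepseekV2MoE / Qwen3MoeSparseMoeBlock.
def op_experts(state: dict, run_moe_core) -> None:
    # Store whatever run_moe_core returns (now a combine input) under one key.
    state["combine_input"] = run_moe_core(dispatch_output=state["dispatch_output"])


def op_combine_a(state: dict, combine_a) -> None:
    # The later op pops that single object; there is no need to re-read
    # topk_ids/topk_weights from the dispatch output any more.
    combine_a(combine_input=state.pop("combine_input"))
    state.pop("dispatch_output")


state = {"dispatch_output": object()}
op_experts(state, run_moe_core=lambda dispatch_output: ("combine", dispatch_output))
op_combine_a(state, combine_a=lambda combine_input: print("combine_a got", combine_input[0]))
print(state)  # {}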
python/sglang/srt/single_batch_overlap.py

@@ -85,7 +85,7 @@ def execute_sbo(
         _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
     )

-    hidden_states = experts.run_moe_core(
+    combine_input = experts.run_moe_core(
         dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
     )
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:

@@ -98,12 +98,7 @@ def execute_sbo(
     ):
         forward_shared_experts()

-    hidden_states = experts.dispatcher.combine(
-        hidden_states=hidden_states,
-        topk_ids=dispatch_output.topk_ids,
-        topk_weights=dispatch_output.topk_weights,
-        overlap_args=combine_overlap_args,
-    )
+    hidden_states = experts.dispatcher.combine(combine_input=combine_input)

     return hidden_states