sglang · commit 32fa1e9c (unverified), authored Jul 31, 2025 by Cheng Wan, committed via GitHub Jul 31, 2025
parent e7dc163f

[4/N] MoE Refactor: Unified Triton Kernel for FusedMoE and EPMoE (#8515)

Showing 6 changed files with 70 additions and 690 deletions (+70, -690)
python/sglang/srt/layers/moe/ep_moe/layer.py                 +15  -648
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py   +22  -4
python/sglang/srt/layers/moe/fused_moe_triton/layer.py       +32  -12
python/sglang/srt/layers/quantization/fp8.py                 +0   -18
python/sglang/srt/layers/quantization/unquant.py             +0   -8
python/sglang/srt/layers/quantization/w4afp8.py              +1   -0
python/sglang/srt/layers/moe/ep_moe/layer.py
This diff is collapsed in the original view and is not reproduced here (+15, -648).
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -413,18 +413,37 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens

-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (
         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
     )
-
-    off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = (
         b_ptr
         + off_experts * stride_be
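Note: the -1 branch above is what lets a single Triton kernel serve both FusedMoE and EPMoE. Under expert parallelism, expert ids owned by other ranks are remapped to -1 before launch, and those blocks write zeros instead of running a GEMM. A minimal host-side sketch of the same idea in plain PyTorch (tensor names and shapes are illustrative, not taken from the kernel):

    import torch

    # Illustrative values, not from the kernel: expert ids after EP remapping,
    # where -1 marks experts owned by another expert-parallel rank.
    expert_ids = torch.tensor([2, -1, 0, -1])
    block_out = torch.randn(4, 16)         # stand-in for per-block kernel output
    block_out[expert_ids == -1] = 0.0      # mirrors what write_zeros_to_output does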
@@ -497,7 +516,6 @@ def fused_moe_kernel(
                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
         else:
-            # fix out of shared memory issue
             if use_fp8_w8a8:
                 accumulator = tl.dot(a, b, acc=accumulator)
             else:
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -12,7 +12,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -79,7 +79,6 @@ class FusedMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
-        skip_quant: Optional[bool] = False,
     ):
         super().__init__()
@@ -95,7 +94,8 @@ class FusedMoE(torch.nn.Module):
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.expert_map = None
+        self.expert_map_cpu = None
+        self.expert_map_gpu = None

         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -104,20 +104,22 @@ class FusedMoE(torch.nn.Module):
         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe

         if enable_ep_moe:
+            # TODO(ch-wan): support shared experts fusion
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
             self.tp_rank = 0
             # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
             assert num_experts % self.ep_size == 0
             self.num_local_experts = num_experts // self.ep_size
-            self.expert_map[
+            self.expert_map_cpu[
                 self.ep_rank
                 * self.num_local_experts : (self.ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
+            self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
         else:
             self.ep_size = 1
             self.ep_rank = 0
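Note: the former single expert_map is split into expert_map_cpu (used during weight loading) and expert_map_gpu (used in forward). A small sketch of what the map contains, with made-up sizes:

    import torch

    num_experts, ep_size, ep_rank = 8, 2, 1    # hypothetical values, not from the diff
    num_local_experts = num_experts // ep_size

    expert_map_cpu = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map_cpu[
        ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts
    ] = torch.arange(0, num_local_experts, dtype=torch.int32)

    print(expert_map_cpu)   # tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)
    # expert_map_gpu is the same tensor copied to the GPU for use in forward().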
@@ -136,9 +138,6 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )

-        if skip_quant:
-            return
-
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -367,9 +366,9 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)

     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map is None:
+        if self.expert_map_cpu is None:
             return expert_id
-        return self.expert_map[expert_id].item()
+        return self.expert_map_cpu[expert_id].item()

     def weight_loader(
         self,
@@ -421,7 +420,6 @@ class FusedMoE(torch.nn.Module):
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
-
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
@@ -614,9 +612,14 @@ class FusedMoE(torch.nn.Module):
         )
         return

-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None

+        if self.expert_map_gpu is not None:
+            topk_output = topk_output._replace(
+                topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
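Note: forward() now folds the global-to-local expert id translation into a single gather on the GPU; _replace works because the top-k result is a NamedTuple. A rough standalone sketch of the pattern (the stand-in class and its fields are assumptions; only the topk_ids handling mirrors the diff):

    from typing import NamedTuple
    import torch

    class TopKOut(NamedTuple):             # stand-in for StandardTopKOutput, fields assumed
        topk_weights: torch.Tensor
        topk_ids: torch.Tensor

    expert_map_gpu = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)  # example: rank owns global experts 2 and 3
    out = TopKOut(torch.rand(2, 2), torch.tensor([[2, 0], [3, 1]]))

    out = out._replace(topk_ids=expert_map_gpu[out.topk_ids])
    print(out.topk_ids)                    # [[0, -1], [1, -1]]: foreign experts become -1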
@@ -670,3 +673,20 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
+
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
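Note: for reference, the new classmethod expands to one tuple per (expert, shard). With num_experts=2 it would return the following (values derived from the code above, not run against the repository):

    FusedMoE.make_expert_input_scale_params_mapping(num_experts=2)
    # [("experts.w13_", "experts.0.w1.", 0, "w1"),
    #  ("experts.w2_",  "experts.0.w2.", 0, "w2"),
    #  ("experts.w13_", "experts.0.w3.", 0, "w3"),
    #  ("experts.w13_", "experts.1.w1.", 1, "w1"),
    #  ("experts.w2_",  "experts.1.w2.", 1, "w2"),
    #  ("experts.w13_", "experts.1.w3.", 1, "w3")]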
python/sglang/srt/layers/quantization/fp8.py

@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
-        elif isinstance(layer, EPMoE):
-            return Fp8EPMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

-        if isinstance(layer, EPMoE):
-            layer.w13_weight_scale = (
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            )
-            layer.w2_weight_scale = (
-                layer.w2_weight_scale_inv
-                if self.block_quant
-                else layer.w2_weight_scale
-            )
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
python/sglang/srt/layers/quantization/unquant.py

@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-        if isinstance(layer, EPMoE):
-            return layer.run_moe(
-                hidden_states=x,
-                topk_output=topk_output,
-            )
-
         return self.forward(
             x=x,
             layer=layer,
python/sglang/srt/layers/quantization/w4afp8.py

@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
+        **kwargs,
     ) -> torch.Tensor:
         # TODO(ch-wan): move it out of this class
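Note: accepting **kwargs lets a unified caller pass one superset of keyword arguments to every quant method's apply() while W4AFp8 simply ignores the ones it does not use. A generic illustration of that calling pattern (function and argument names are invented for the example, not the sglang API):

    def apply_w4afp8(x, *, activation="silu", **kwargs):    # tolerant signature: extras absorbed
        return x

    def apply_fp8(x, *, activation="silu", routed_scaling_factor=None, **kwargs):
        return x if routed_scaling_factor is None else x * routed_scaling_factor

    # A unified caller can pass the same keyword set to either implementation.
    common_kwargs = dict(activation="silu", routed_scaling_factor=None, no_combine=False)
    apply_w4afp8(1.0, **common_kwargs)
    apply_fp8(1.0, **common_kwargs)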