change / sglang / Commits / 62eff37b

Commit 62eff37b (unverified)
Refactor Triton-kernel MoE runner integration (#11795)
Authored Oct 23, 2025 by Jonah Bernard; committed by GitHub on Oct 23, 2025
Parent: 47e12e08

Showing 11 changed files with 325 additions and 112 deletions (+325 / -112)
Changed files:

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py                +1    -1
  python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py   +7    -4
  python/sglang/srt/layers/moe/moe_runner/runner.py                     +3    -0
  python/sglang/srt/layers/moe/moe_runner/triton_kernels.py (new file)  +194  -0
  python/sglang/srt/layers/moe/token_dispatcher/base.py                 +6    -0
  python/sglang/srt/layers/moe/token_dispatcher/standard.py             +1    -1
  python/sglang/srt/layers/moe/topk.py                                  +4    -4
  python/sglang/srt/layers/moe/utils.py                                 +3    -3
  python/sglang/srt/layers/quantization/mxfp4.py                        +32   -38
  python/sglang/srt/layers/quantization/unquant.py                      +31   -46
  test/srt/test_triton_fused_moe.py                                     +43   -15
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py

@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output

@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
     b2: torch.Tensor,
     topk_output: TopKOutput,
     moe_runner_config: MoeRunnerConfig,
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,

@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
 ) -> torch.Tensor:
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output

@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
         scatter_indx=scatter_idx,
         inplace=False,  # triton kernel doesn't support inplace
         activation=moe_runner_config.activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
         global_num_experts=global_num_experts,

@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
     scatter_indx: ScatterIndx,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,

@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         gather_indx=gather_indx,
         precision_config=w1_pcg,
-        gammas=None,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
         fused_activation=act,
     )

@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         scatter_indx=scatter_indx,
         precision_config=w2_pcg,
-        gammas=routing_data.gate_scal,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
     )
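
Note on the two gammas edits above: they move the per-token routing weights (routing_data.gate_scal) from the down-projection matmul to the gather matmul whenever apply_router_weight_on_input is set, so the weight is applied exactly once per expert path. A standalone torch sketch (not using triton_kernels) of why the placement is interchangeable for a purely linear expert; with a nonlinearity between the matmuls the two orderings differ, which is why it stays a configuration flag:

import torch

# Sketch: for a purely linear expert, scaling the gathered tokens before the
# first matmul or scaling the outputs after the second matmul gives the same
# combined result, since the per-row scale commutes with the matmuls.
torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64)          # gathered tokens
w1 = torch.randn(8, 16, dtype=torch.float64)        # first expert weight
w2 = torch.randn(16, 8, dtype=torch.float64)        # second expert weight
gate_scal = torch.rand(4, 1, dtype=torch.float64)   # per-token routing weights

out_scale_first = ((x * gate_scal) @ w1) @ w2   # apply_router_weight_on_input=True
out_scale_last = ((x @ w1) @ w2) * gate_scal    # apply_router_weight_on_input=False

torch.testing.assert_close(out_scale_first, out_scale_last)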
python/sglang/srt/layers/moe/moe_runner/runner.py

@@ -11,6 +11,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
 )
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend

 if TYPE_CHECKING:

@@ -31,6 +32,8 @@ class MoeRunner:
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_triton_kernels():
+            self.runner_core = TritonKernelsRunnerCore(config)
         elif runner_backend.is_deep_gemm():
             self.runner_core = DeepGemmRunnerCore(config)
         else:
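
The single new elif is the only wiring MoeRunner needs: each backend predicate maps to exactly one runner core. A minimal, self-contained sketch of that dispatch pattern with stand-in names (the real enum and core classes live in sglang.srt.layers.moe):

from dataclasses import dataclass
from enum import Enum


# Illustrative stand-ins only; not the sglang classes.
class Backend(Enum):
    TRITON = "triton"
    TRITON_KERNELS = "triton_kernel"
    DEEP_GEMM = "deep_gemm"

    def is_triton(self):
        return self is Backend.TRITON

    def is_triton_kernels(self):
        return self is Backend.TRITON_KERNELS

    def is_deep_gemm(self):
        return self is Backend.DEEP_GEMM


@dataclass
class RunnerSketch:
    backend: Backend

    def __post_init__(self):
        # Mirrors the elif chain added in MoeRunner.__init__: one backend enum
        # member selects exactly one runner core implementation.
        if self.backend.is_triton():
            self.core = "TritonRunnerCore"
        elif self.backend.is_triton_kernels():
            self.core = "TritonKernelsRunnerCore"
        elif self.backend.is_deep_gemm():
            self.core = "DeepGemmRunnerCore"
        else:
            raise NotImplementedError(f"unsupported backend: {self.backend}")


print(RunnerSketch(Backend.TRITON_KERNELS).core)  # -> TritonKernelsRunnerCore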
python/sglang/srt/layers/moe/moe_runner/triton_kernels.py (new file, mode 100644, +194)

"""Triton kernels MoE runner backend skeleton."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.moe.moe_runner.base import (
    MoeQuantInfo,
    MoeRunnerConfig,
    MoeRunnerCore,
    RunnerInput,
    RunnerOutput,
    register_post_permute,
    register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend

if TYPE_CHECKING:
    from triton_kernels.matmul_ogs import PrecisionConfig
    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx

    from sglang.srt.layers.moe.token_dispatcher.standard import (
        StandardCombineInput,
        StandardDispatchOutput,
    )


# ---------------------------------------------------------------------------
# Runner IO dataclasses
# ---------------------------------------------------------------------------


@dataclass
class TritonKernelsRunnerInput(RunnerInput):
    """Input bundle passed to the triton-kernels runner core."""

    hidden_states: torch.Tensor
    routing_data: "RoutingData"
    gather_indx: "GatherIndx"
    scatter_indx: "ScatterIndx"

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON_KERNELS


@dataclass
class TritonKernelsRunnerOutput(RunnerOutput):
    """Output bundle returned from the triton-kernels runner core."""

    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON_KERNELS


@dataclass
class TritonKernelsQuantInfo(MoeQuantInfo):
    """Quantization payload consumed by the triton-kernels backend."""

    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    w13_bias: Optional[torch.Tensor] = None
    w2_bias: Optional[torch.Tensor] = None
    w13_precision_config: Optional[PrecisionConfig] = None
    w2_precision_config: Optional[PrecisionConfig] = None
    global_num_experts: int = -1


# ---------------------------------------------------------------------------
# Runner core
# ---------------------------------------------------------------------------


class TritonKernelsRunnerCore(MoeRunnerCore):
    """Execute MoE experts via the external triton_kernels package."""

    def run(
        self,
        runner_input: TritonKernelsRunnerInput,
        quant_info: TritonKernelsQuantInfo,
        running_state: dict,
    ) -> TritonKernelsRunnerOutput:
        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
            triton_kernel_fused_experts,
            triton_kernel_fused_experts_with_bias,
        )

        hidden_states = runner_input.hidden_states
        common_kwargs = dict(
            routing_data=runner_input.routing_data,
            gather_indx=runner_input.gather_indx,
            scatter_indx=(
                None if self.config.no_combine else runner_input.scatter_indx
            ),
            inplace=False,
            activation=self.config.activation,
            apply_router_weight_on_input=self.config.apply_router_weight_on_input,
            global_num_experts=quant_info.global_num_experts,
        )

        has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
        if has_bias:
            assert (
                quant_info.w13_bias is not None and quant_info.w2_bias is not None
            ), "Bias execution requires both w13_bias and w2_bias"
            output = triton_kernel_fused_experts_with_bias(
                hidden_states=hidden_states,
                w1=quant_info.w13_weight,
                w1_pcg=quant_info.w13_precision_config,
                b1=quant_info.w13_bias,
                w2=quant_info.w2_weight,
                w2_pcg=quant_info.w2_precision_config,
                b2=quant_info.w2_bias,
                gemm1_alpha=self.config.gemm1_alpha,
                gemm1_clamp_limit=self.config.gemm1_clamp_limit,
                **common_kwargs,
            )
        else:
            output = triton_kernel_fused_experts(
                hidden_states=hidden_states,
                w1=quant_info.w13_weight,
                w2=quant_info.w2_weight,
                **common_kwargs,
            )

        if self.config.no_combine:
            tokens = runner_input.hidden_states.shape[0]
            hidden = runner_input.hidden_states.shape[-1]
            total_rows = output.shape[0]
            top_k = total_rows // tokens
            output = output.view(tokens, top_k, hidden)

        return TritonKernelsRunnerOutput(hidden_states=output)

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        return MoeRunnerBackend.TRITON_KERNELS


# ---------------------------------------------------------------------------
# Permute / fused hooks
# ---------------------------------------------------------------------------


@register_pre_permute("standard", "triton_kernel")
def pre_permute_standard_to_triton_kernels(
    dispatch_output: "StandardDispatchOutput",
    quant_info: TritonKernelsQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> TritonKernelsRunnerInput:
    from sglang.srt.layers.moe.topk import TopKOutputChecker

    hidden_states = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    assert TopKOutputChecker.format_is_triton_kernels(
        topk_output
    ), "Triton-kernel runner expects TritonKernelTopKOutput"
    routing_data, gather_indx, scatter_indx = topk_output
    return TritonKernelsRunnerInput(
        hidden_states=hidden_states,
        routing_data=routing_data,
        gather_indx=gather_indx,
        scatter_indx=scatter_indx,
    )


@register_post_permute("triton_kernel", "standard")
def post_permute_triton_kernels_to_standard(
    runner_output: TritonKernelsRunnerOutput,
    quant_info: TritonKernelsQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> StandardCombineInput:
    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput

    hidden_states = runner_output.hidden_states
    if (
        runner_config.routed_scaling_factor is not None
        and runner_config.routed_scaling_factor != 1.0
        and not runner_config.no_combine
    ):
        hidden_states.mul_(runner_config.routed_scaling_factor)
    return StandardCombineInput(hidden_states=hidden_states)
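
The @register_pre_permute("standard", "triton_kernel") and @register_post_permute("triton_kernel", "standard") decorators above key the conversion hooks by a (source format, target format) pair. A toy sketch of that registration pattern, assuming a plain dict-based registry; the real decorators in sglang.srt.layers.moe.moe_runner.base may differ in signature and bookkeeping:

# Hypothetical registry sketch; names mirror the diff but are not the sglang API.
_PRE_PERMUTE_HOOKS = {}


def register_pre_permute(dispatch_format: str, runner_format: str):
    def decorator(fn):
        # Store the hook under the (source, target) format pair.
        _PRE_PERMUTE_HOOKS[(dispatch_format, runner_format)] = fn
        return fn

    return decorator


@register_pre_permute("standard", "triton_kernel")
def standard_to_triton_kernel(dispatch_output):
    # In the real hook this unpacks the TopK output into
    # routing_data / gather_indx / scatter_indx for the runner core.
    return {"hidden_states": dispatch_output}


hook = _PRE_PERMUTE_HOOKS[("standard", "triton_kernel")]
print(hook("tokens"))  # -> {'hidden_states': 'tokens'}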
python/sglang/srt/layers/moe/token_dispatcher/base.py

@@ -28,6 +28,12 @@ class DispatchOutputChecker:
     ) -> TypeGuard[StandardDispatchOutput]:
         return dispatch_output.format.is_standard()

+    @staticmethod
+    def format_is_triton_kernels(
+        dispatch_output: DispatchOutput,
+    ) -> TypeGuard[StandardDispatchOutput]:
+        return dispatch_output.format.is_standard()
+
     @staticmethod
     def format_is_deepep_normal(
         dispatch_output: DispatchOutput,
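
The new format_is_triton_kernels checker follows the same TypeGuard convention as its neighbours: a True return lets a static type checker narrow the union type inside the calling branch. A generic illustration with toy classes (not the sglang types):

from typing import TypeGuard, Union  # TypeGuard requires Python 3.10+


class StandardOutput:
    format = "standard"


class OtherOutput:
    format = "other"


def format_is_standard(
    out: Union[StandardOutput, OtherOutput],
) -> TypeGuard[StandardOutput]:
    # Returning True tells the type checker that `out` is a StandardOutput.
    return out.format == "standard"


def consume(out: Union[StandardOutput, OtherOutput]) -> None:
    if format_is_standard(out):
        # Inside this branch, type checkers treat `out` as StandardOutput.
        print("standard path:", out.format)
    else:
        print("fallback path:", out.format)


consume(StandardOutput())
consume(OtherOutput())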
python/sglang/srt/layers/moe/token_dispatcher/standard.py

@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()

         return StandardDispatchOutput(
python/sglang/srt/layers/moe/topk.py

@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()

     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()

     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:

@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD

-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL

     def is_bypassed(self) -> bool:

@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
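
The renamed predicate only changes the fallback branch of output-format selection: an explicit topk_config.output_format still takes priority, and the backend-driven default is consulted only when none is set. A toy sketch of that priority order, with illustrative names rather than the sglang API:

# Hypothetical helper illustrating the selection order above; string values
# stand in for the real TopKOutputFormat enum members.
def resolve_output_format(explicit_format, backend_is_triton_kernels: bool) -> str:
    if explicit_format is not None:
        return explicit_format                 # explicit config wins
    if backend_is_triton_kernels:
        return "TRITON_KERNEL"                 # backend-driven default
    return "STANDARD"


print(resolve_output_format(None, True))        # -> TRITON_KERNEL
print(resolve_output_format("BYPASSED", True))  # -> BYPASSED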
python/sglang/srt/layers/moe/utils.py

@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"

@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON

-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS

     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
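
Only the enum member name changes here (TRITON_KERNEL to TRITON_KERNELS); the string value stays "triton_kernel", so value-based lookups (for example from a server argument) keep resolving to the same backend. A stand-in enum demonstrating that property:

from enum import Enum


# Stand-in enum, not the sglang class: the member is renamed but its value is
# preserved, so Enum("value") lookups behave exactly as before the rename.
class MoeRunnerBackendSketch(Enum):
    TRITON = "triton"
    TRITON_KERNELS = "triton_kernel"


assert MoeRunnerBackendSketch("triton_kernel") is MoeRunnerBackendSketch.TRITON_KERNELS
print(MoeRunnerBackendSketch("triton_kernel").name)  # -> TRITON_KERNELS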
python/sglang/srt/layers/quantization/mxfp4.py

@@ -261,26 +261,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.flashinfer_mxfp4_moe_precision = (
             get_global_server_args().flashinfer_mxfp4_moe_precision
         )
-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

     def create_weights(
         self,
         layer: torch.nn.Module,

@@ -600,7 +587,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,

@@ -677,31 +671,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)

-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,

@@ -709,7 +703,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)


 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
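
create_moe_runner now resolves an "auto" backend to either the triton-kernels runner or the plain Triton runner based on self.use_triton_kernels, instead of hard-coding MoeRunnerBackend.TRITON. A toy sketch of that resolution with string stand-ins for the enum members:

# Hypothetical helper mirroring the auto-resolution added above; strings stand
# in for MoeRunnerBackend members.
def resolve_backend(configured: str, use_triton_kernels: bool) -> str:
    if configured == "auto":
        # "auto" defers to whether triton_kernels is requested and available.
        return "triton_kernel" if use_triton_kernels else "triton"
    return configured


print(resolve_backend("auto", True))       # -> triton_kernel
print(resolve_backend("auto", False))      # -> triton
print(resolve_backend("deep_gemm", True))  # -> deep_gemm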
python/sglang/srt/layers/quantization/unquant.py

@@ -8,7 +8,12 @@ from torch.nn.parameter import Parameter
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
+    get_moe_runner_backend,
+)
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -115,13 +120,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if use_intel_amx_backend(layer):
             x_shapes = x.shape
             if len(x_shapes) == 3:
                 x = x.view(-1, x.shape[-1])
             output = torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
+                x, layer.weight, bias, True,  # is_vnni
             )
             if len(x_shapes) == 3:
                 output = output.view(x_shapes[0], x_shapes[1], -1)

@@ -138,19 +145,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = False

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and use_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,

@@ -231,14 +225,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
         return self.forward(
             layer=layer,
             dispatch_output=dispatch_output,

@@ -249,7 +249,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states

@@ -257,30 +256,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         moe_runner_config = self.moe_runner_config

-        if self.use_triton_kernels:
-            if self.with_bias:
-                assert self.triton_kernel_moe_with_bias_forward is not None
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                    w1_pcg=None,
-                    w2_pcg=None,
-                )
-            else:
-                assert self.triton_kernel_moe_forward is not None
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+            )
+            return self.runner.run(dispatch_output, quant_info)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"

@@ -311,7 +299,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
                 return StandardCombineInput(hidden_states=output)
             else:
                 quant_info = TritonMoeQuantInfo(
                     w13_weight=layer.w13_weight,
                     w2_weight=layer.w2_weight,

@@ -325,7 +312,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states

@@ -380,7 +366,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
         import torch_npu

         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
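
As in the mxfp4 path, the unquantized path now hands optional biases to the runner with getattr(layer, ..., None), and the runner core selects the with-bias kernels only when a bias is actually present. A standalone sketch of that convention (toy modules, not the sglang layer classes):

import torch


# Toy layers illustrating the optional-bias convention used when building
# TritonKernelsQuantInfo: missing attributes simply resolve to None.
class LayerWithBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w13_weight_bias = torch.zeros(4)
        self.w2_weight_bias = torch.zeros(4)


class LayerWithoutBias(torch.nn.Module):
    pass


def has_bias(layer: torch.nn.Module) -> bool:
    # Mirrors the has_bias check in the runner core: either bias present
    # switches execution to the with-bias kernel path.
    w13_bias = getattr(layer, "w13_weight_bias", None)
    w2_bias = getattr(layer, "w2_weight_bias", None)
    return w13_bias is not None or w2_bias is not None


print(has_bias(LayerWithBias()))     # -> True
print(has_bias(LayerWithoutBias()))  # -> False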
test/srt/test_triton_fused_moe.py

@@ -5,11 +5,10 @@ import torch.nn.functional as F
 from tqdm import tqdm

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-    triton_kernel_moe_forward,
-)
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.test.test_utils import CustomTestCase

@@ -55,6 +54,7 @@ class TestFusedMOE(CustomTestCase):
         w2,
         score,
         topk,
+        return_per_expert: bool = False,
     ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)

@@ -78,9 +78,14 @@ class TestFusedMOE(CustomTestCase):
                     a[mask] @ w1_compute[i].transpose(0, 1)
                 )
             ) @ w2_compute[i].transpose(0, 1)
-        return (
-            out.view(B, -1, w2.shape[1])
-            * topk_weight.view(B, -1, 1).to(out.dtype)
-        ).sum(dim=1)
+        weighted = out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(
+            out.dtype
+        )
+        if return_per_expert:
+            return weighted
+        return weighted.sum(dim=1)

     def _test_case(self, m, n, k, e, topk, dtype):
         rtol, atol = self.get_tolerance(dtype)

@@ -99,20 +104,43 @@ class TestFusedMOE(CustomTestCase):
             renormalize=False,
             use_grouped_topk=False,
         )
-        topk_op.use_triton_kernels = True
+        topk_op.topk_config.output_format = TopKOutputFormat.TRITON_KERNEL
         triton_topk_output = topk_op.forward_cuda(
             hidden_states=a,
             router_logits=score,
         )

-        moe_runner_config = MoeRunnerConfig(
-            inplace=False,
-        )
-        triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
-        )
-
-        torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
-        torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
+        quant_info = TritonKernelsQuantInfo(w13_weight=w1_tri, w2_weight=w2_tri)
+        dispatch_output = StandardDispatchOutput(
+            hidden_states=a, topk_output=triton_topk_output
+        )
+
+        torch_per_expert = self.torch_naive_moe(
+            a, w1, w2, score, topk, return_per_expert=True
+        )
+        torch_combined = torch_per_expert.sum(dim=1)
+
+        def run_runner(config):
+            runner = MoeRunner(MoeRunnerBackend.TRITON_KERNELS, config)
+            result = runner.run(dispatch_output, quant_info)
+            return result.hidden_states
+
+        # Combined output (no_combine=False)
+        non_fused_config = MoeRunnerConfig(inplace=False)
+        non_fused_output = run_runner(non_fused_config)
+        torch.testing.assert_close(
+            non_fused_output, torch_combined, rtol=rtol, atol=atol
+        )
+
+        # Per-expert output (no_combine=True)
+        non_fused_no_combine_config = MoeRunnerConfig(
+            inplace=False, no_combine=True, top_k=topk
+        )
+        non_fused_no_combine_output = run_runner(non_fused_no_combine_config)
+        torch.testing.assert_close(
+            non_fused_no_combine_output, torch_per_expert, rtol=rtol, atol=atol
+        )

     def test_various_configurations(self):
         m_values = [1, 32, 64, 256]
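
The new no_combine=True test relies on the reshape the runner core performs: the flat (tokens * top_k, hidden) expert output is viewed as (tokens, top_k, hidden), and summing over the top_k axis should reproduce the combined output. A standalone torch check of that shape arithmetic:

import torch

# Standalone check of the no_combine reshape: top_k is inferred from the row
# count of the flat expert output, and the per-expert view sums back to the
# combined shape.
tokens, top_k, hidden = 4, 2, 8
flat = torch.randn(tokens * top_k, hidden)

inferred_top_k = flat.shape[0] // tokens
per_expert = flat.view(tokens, inferred_top_k, hidden)
combined = per_expert.sum(dim=1)

assert per_expert.shape == (tokens, top_k, hidden)
assert combined.shape == (tokens, hidden)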