sglang · Commit 253454de (unverified)

Integrate triton moe kernel (#7689)

Authored by Yuan Luo, Jul 07, 2025; committed by GitHub, Jul 06, 2025.
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Parent: ea3e7ffe
Showing 7 changed files with 697 additions and 54 deletions (+697, -54):

benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py   (+271, -0)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                (+2, -0)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                    (+95, -54)
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py       (+176, -0)
python/sglang/srt/managers/schedule_batch.py                              (+1, -0)
python/sglang/srt/server_args.py                                          (+6, -0)
test/srt/test_triton_fused_moe.py                                         (+146, -0)
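Everything below is opt-in behind a new server flag added in python/sglang/srt/server_args.py. A minimal way to turn it on, assuming the usual sglang launch entrypoint and a placeholder model path:

# python3 -m sglang.launch_server --model-path <moe-model-path> --enable-triton-kernel-moe

The flag is propagated to workers through GLOBAL_SERVER_ARGS_KEYS (schedule_batch.py), and the new forward path additionally requires a CUDA device with the optional triton_kernels package importable (see the guard in layer.py below).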
benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py  (new file, mode 100644, +271)

# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
import argparse

import torch
import triton
from transformers import AutoConfig

from sglang.srt.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe as fused_moe_sglang,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)


def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    if config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = (
            config.n_routed_experts + 1
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
    }
    print(f"{shape_configs=}")
    return shape_configs


def fused_moe_triton_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
):
    return triton_kernel_moe_forward(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=False,
    )


def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
    block_shape=None,
):
    return fused_moe_sglang(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=False,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
        line_arg="provider",
        line_vals=[
            "sglang_fused_moe_triton_v340",
            "sglang_fused_moe_triton",
        ],
        line_names=[
            "sglang_fused_moe_triton_v340",
            "sglang_fused_moe_triton",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(
    batch_size,
    provider,
    model_config,
    use_fp8_w8a8=False,
    use_cuda_graph: bool = False,
):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]
    block_shape = model_config["block_shape"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
    w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype)

    w1_tri = w1.clone()
    w2_tri = w2.clone()
    w1_tri = w1_tri.transpose(-2, -1).contiguous()
    w2_tri = w2_tri.transpose(-2, -1).contiguous()

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    if provider == "sglang_fused_moe_triton_v340":
        api_func = fused_moe_triton_api
        api_kwargs = {
            "x": x,
            "w1": w1_tri,
            "w2": w2_tri,
            "input_gating": input_gating,
            "topk": topk,
        }
    else:
        api_func = fused_moe_sglang_api
        api_kwargs = {
            "x": x,
            "w1": w1,
            "w2": w2,
            "input_gating": input_gating,
            "topk": topk,
            "use_fp8_w8a8": use_fp8_w8a8,
            "block_shape": block_shape,
        }

    # Warmup
    for _ in range(10):
        _ = api_func(**api_kwargs)
    torch.cuda.synchronize()

    if use_cuda_graph:
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            api_func(**api_kwargs)
        torch.cuda.synchronize()
        bench_lambda = lambda: graph.replay()
    else:
        bench_lambda = lambda: api_func(**api_kwargs)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
    return ms, min_ms, max_ms


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8-w8a8", action="store_true")
    parser.add_argument(
        "--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/sglang_fused_moe/",
    )
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    try:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo",
                init_method="tcp://127.0.0.1:23456",
                world_size=1,
                rank=0,
            )
        init_distributed_environment(
            world_size=1,
            rank=0,
            distributed_init_method="tcp://127.0.0.1:23456",
            local_rank=0,
            backend="nccl" if torch.cuda.is_available() else "gloo",
        )
        initialize_model_parallel(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
        )
        model_config = get_model_config(args.model, args.tp_size)
        benchmark.run(
            show_plots=True,
            print_data=True,
            save_path=args.save_path,
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            use_cuda_graph=args.use_cuda_graph,
        )
    finally:
        destroy_model_parallel()
        destroy_distributed_environment()


if __name__ == "__main__":
    main()
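For reference, a typical invocation of this script (the model path is the placeholder from the file's own header comment; --use-cuda-graph exercises the capture/replay path above):

# python3 benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-cuda-graph

It prints the shape_configs derived from the model config, then the fused-moe-performance results per batch size for the two providers (do_bench reports the 0.5/0.2/0.8 latency quantiles in ms).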
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (+2, -0)

@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1822,6 +1823,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
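The only functional change here is that fused_moe now accepts apply_router_weight_on_input and threads it through to fused_experts. A hedged call sketch follows; the keyword name comes from this hunk, the semantics (fold router weights into the inputs rather than the outputs) is inferred from the name, and the sketch assumes the underlying fused_experts supports it for this configuration, so it conservatively uses topk=1:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

tokens, hidden, inter, E, topk = 16, 128, 256, 8, 1
x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.bfloat16, device="cuda")  # fused gate_up_proj
w2 = torch.randn(E, hidden, inter, dtype=torch.bfloat16, device="cuda")      # down_proj
gating = torch.randn(tokens, E, dtype=torch.float32, device="cuda")

out = fused_moe(
    x, w1, w2, gating, topk,
    renormalize=False,
    apply_router_weight_on_input=True,  # new keyword added by this commit
)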
python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+95, -54)

 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )

+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype
             ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
@@ -192,6 +210,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -228,7 +257,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights,
                 topk_ids,
                 activation=(
                     ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                 ),
             )
         else:
@@ -475,9 +506,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )

         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod(self.use_triton_kernels)
+            )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -597,6 +632,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -630,6 +667,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -716,6 +755,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
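The layer-side change is mostly about weight layout: when use_triton_kernels is set, create_weights swaps the last two dimensions of both expert weights, the weight loader transposes incoming shards to match, and apply() hands the router logits straight to triton_kernel_moe_forward instead of going through select_experts/fused_experts. A small self-contained sketch of the layout swap (the helper name and toy sizes are illustrative, not part of the diff):

import torch

def expert_weight_shapes(num_experts: int, hidden_size: int, intermediate_size: int,
                         use_triton_kernels: bool):
    # Default layouts used by the existing fused_moe kernels
    w13_n, w13_k = 2 * intermediate_size, hidden_size  # fused gate_up_proj
    w2_n, w2_k = hidden_size, intermediate_size        # down_proj
    if use_triton_kernels:
        # triton_kernels' matmul_ogs reads the transposed layout
        w13_n, w13_k = w13_k, w13_n
        w2_n, w2_k = w2_k, w2_n
    return (num_experts, w13_n, w13_k), (num_experts, w2_n, w2_k)

print(expert_weight_shapes(8, 4096, 14336, use_triton_kernels=False))
# ((8, 28672, 4096), (8, 4096, 14336))
print(expert_weight_shapes(8, 4096, 14336, use_triton_kernels=True))
# ((8, 4096, 28672), (8, 14336, 4096))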
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py  (new file, mode 100644, +176)

# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035

from typing import Optional

import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing

from sglang.srt.utils import direct_register_custom_op


def triton_kernel_moe_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    if not renormalize:
        gating_output = torch.softmax(gating_output, dim=-1)
    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)

    return triton_kernel_fused_experts(
        hidden_states,
        w1,
        w2,
        routing_data,
        gather_idx,
        scatter_idx,
        inplace=inplace,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
    )


# This is a triton implementation of the fused_experts function
def triton_kernel_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    routing_data: RoutingData,
    gather_indx: GatherIndx,
    scatter_indx: ScatterIndx,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
    assert per_channel_quant == False, "per_channel_quant is not supported"
    assert expert_map == None, "expert_map is not supported"
    assert w1_scale == None, "w1_scale is not supported"
    assert w2_scale == None, "w2_scale is not supported"
    assert a1_scale == None, "a1_scale is not supported"
    assert a2_scale == None, "a2_scale is not supported"
    assert block_shape == None, "block_shape is not supported"

    # type check
    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"

    # Shape check
    assert hidden_states.ndim == 2, "hidden_states must be 2D"
    assert (
        hidden_states.shape[-1] == w1.shape[-2]
    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
    assert (
        w2.shape[-1] == w1.shape[1]
    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

    # feature check
    assert inplace == False, "Inplace is not supported in new triton MoE kernel"

    M, K = hidden_states.shape
    E, _, N = w1.shape
    n_expts_act = routing_data.n_expts_act
    dtype = hidden_states.dtype

    if global_num_experts == -1:
        global_num_experts = E

    # consistent with default implementation
    intermediate_cache2 = torch.empty(
        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
    )

    intermediate_cache1 = matmul_ogs(
        hidden_states,
        w1,
        None,
        routing_data,
        gather_indx=gather_indx,
        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
    )

    if activation == "silu":
        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
    elif activation == "gelu":
        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    intermediate_cache3 = matmul_ogs(
        intermediate_cache2,
        w2,
        None,
        routing_data,
        scatter_indx=scatter_indx,
        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
    )

    return intermediate_cache3


def triton_kernel_moe_forward_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="forward_cuda_triton",
    op_func=triton_kernel_moe_forward,
    mutates_args=[],
    fake_impl=triton_kernel_moe_forward_fake,
)
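Given the asserts above (bfloat16 only, 2D hidden states, no fp8/scales/expert map, weights in the transposed layout), a minimal direct call looks roughly like the sketch below. It is not part of the commit; it assumes a CUDA GPU with the optional triton_kernels package installed and mirrors the shapes used by test/srt/test_triton_fused_moe.py further down:

import torch

from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)

tokens, hidden, inter, E, topk = 64, 128, 256, 8, 2
x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda")
# triton_kernels layout: w1 is (E, hidden, 2*inter), w2 is (E, inter, hidden)
w1 = torch.randn(E, hidden, 2 * inter, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, inter, hidden, dtype=torch.bfloat16, device="cuda")
gating = torch.randn(tokens, E, dtype=torch.bfloat16, device="cuda")

out = triton_kernel_moe_forward(x, w1, w2, gating, topk, renormalize=False)
print(out.shape)  # expected: (tokens, hidden)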
python/sglang/srt/managers/schedule_batch.py  (+1, -0)

@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
+    "enable_triton_kernel_moe",
 ]

 # Put some global args for easy access
python/sglang/srt/server_args.py  (+6, -0)

@@ -222,6 +222,7 @@ class ServerArgs:
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
+    enable_triton_kernel_moe: bool = False
     warmups: Optional[str] = None

     # Debug tensor dumps
@@ -1554,6 +1555,11 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="Use triton moe grouped gemm kernel.",
+        )
         parser.add_argument(
             "--warmups",
             type=str,
test/srt/test_triton_fused_moe.py  (new file, mode 100644, +146)

import unittest

import torch
import torch.nn.functional as F
from tqdm import tqdm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)
from sglang.test.test_utils import CustomTestCase


class TestFusedMOE(CustomTestCase):
    NUM_EXPERTS = [8, 64]
    TOP_KS = [2, 4]

    @staticmethod
    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
        """Create a random CUDA tensor

        Args:
            shape: Tensor shape
            dtype: Data type
            mean: Mean value
            std: Standard deviation

        Returns:
            torch.Tensor: Randomly initialized CUDA tensor
        """
        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)

    def get_tolerance(self, dtype):
        """Get tolerance values for different data types

        Args:
            dtype: Data type

        Returns:
            tuple: (relative tolerance, absolute tolerance)
        """
        if dtype == torch.float32:
            return 1e-5, 1e-5
        elif dtype in [torch.float16, torch.bfloat16]:
            return 1e-5, 1e-5
        else:
            return 1e-2, 1e-2  # Default values for other types

    def torch_naive_moe(
        self,
        a,
        w1,
        w2,
        score,
        topk,
    ):
        B, D = a.shape
        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        topk_weight = topk_weight.view(-1)
        topk_ids = topk_ids.view(-1)

        if w1.dtype == torch.float8_e4m3fn:
            w1_compute = w1.to(a.dtype)
            w2_compute = w2.to(a.dtype)
        else:
            w1_compute = w1
            w2_compute = w2

        for i in range(w1_compute.shape[0]):
            mask = topk_ids == i
            if mask.sum():
                out[mask] = SiluAndMul()(
                    a[mask] @ w1_compute[i].transpose(0, 1)
                ) @ w2_compute[i].transpose(0, 1)

        return (
            out.view(B, -1, w2.shape[1])
            * topk_weight.view(B, -1, 1).to(out.dtype)
        ).sum(dim=1)

    def _test_case(self, m, n, k, e, topk, dtype):
        rtol, atol = self.get_tolerance(dtype)

        a = self.create_random_cuda_tensor((m, k), dtype)
        w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
        w2 = self.create_random_cuda_tensor((e, k, n), dtype)

        w1_tri = w1.clone()
        w2_tri = w2.clone()
        w1_tri = w1_tri.transpose(-2, -1).contiguous()
        w2_tri = w2_tri.transpose(-2, -1).contiguous()

        score = self.create_random_cuda_tensor((m, e), dtype)

        triton_output = triton_kernel_moe_forward(
            a, w1_tri, w2_tri, score, topk, renormalize=False
        )
        torch_output = self.torch_naive_moe(a, w1, w2, score, topk)

        torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)

    def test_various_configurations(self):
        m_values = [1, 32, 64, 256]
        n_values = [128, 1024]
        k_values = [128, 512, 1024]
        dtypes = [torch.bfloat16]

        # Calculate total number of tests
        total_tests = (
            len(m_values)
            * len(n_values)
            * len(k_values)
            * len(self.NUM_EXPERTS)
            * len(self.TOP_KS)
            * len(dtypes)
        )

        # Create progress bar
        with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
            for m in m_values:
                for n in n_values:
                    for k in k_values:
                        for e in self.NUM_EXPERTS:
                            for topk in self.TOP_KS:
                                for dtype in dtypes:
                                    with self.subTest(
                                        m=m,
                                        n=n,
                                        k=k,
                                        e=e,
                                        topk=topk,
                                        dtype=dtype,
                                    ):
                                        self._test_case(
                                            m,
                                            n,
                                            k,
                                            e,
                                            topk,
                                            dtype,
                                        )
                                        torch.cuda.empty_cache()
                                        pbar.update(1)


if __name__ == "__main__":
    unittest.main()
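The test compares the new kernel against a naive PyTorch reference across m/n/k/expert/top-k combinations in bfloat16, and it needs a CUDA GPU with the triton_kernels package available. A typical way to run it, following the file's unittest entry point:

# python3 test/srt/test_triton_fused_moe.py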