change / sglang · Commits

Commit 84727a51 (unverified)
Authored Jun 12, 2025 by Yuan Luo; committed via GitHub Jun 11, 2025
Parent: ef326774

[sgl-kernel] Add cuda kernel for moe_ep_silu_and_mul (#6919)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

Showing 8 changed files with 381 additions and 0 deletions (+381, -0):
- sgl-kernel/CMakeLists.txt (+1, -0)
- sgl-kernel/benchmark/bench_moe_silu_and_mul.py (+92, -0)
- sgl-kernel/csrc/common_extension.cc (+4, -0)
- sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu (+115, -0)
- sgl-kernel/include/sgl_kernel_ops.h (+8, -0)
- sgl-kernel/python/sgl_kernel/__init__.py (+1, -0)
- sgl-kernel/python/sgl_kernel/moe.py (+18, -0)
- sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py (+142, -0)
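The new op implements the per-expert SiLU-and-mul stage of expert-parallel MoE. Reading the kernel below, and writing H for `hidden_size`: each row of `gateup_output` is a `[gate | up]` pair of width H/2, and for each reordered token t whose expert e = `reorder_topk_ids[t]` lies in `[start_expert_id, end_expert_id]` it computes, for 0 ≤ j < H/2:

```latex
\mathrm{down\_input}[t, j] =
  \mathrm{silu}\!\big(\mathrm{gateup\_output}[t, j]\big)\cdot
  \mathrm{gateup\_output}[t, j + H/2]\cdot
  \frac{1}{\mathrm{scales}[e - \mathrm{start\_expert\_id}]},
\qquad
\mathrm{silu}(x) = \frac{x}{1 + e^{-x}}
```

When no `scales` tensor is supplied the scale factor is 1, and rows whose expert id falls outside the local range are left untouched.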
sgl-kernel/CMakeLists.txt

```diff
@@ -237,6 +237,7 @@ set(SOURCES
     "csrc/moe/fp8_blockwise_moe_kernel.cu"
     "csrc/moe/prepare_moe_input.cu"
     "csrc/moe/ep_moe_reorder_kernel.cu"
+    "csrc/moe/ep_moe_silu_and_mul_kernel.cu"
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/speculative/packbit.cu"
```
sgl-kernel/benchmark/bench_moe_silu_and_mul.py (new file, mode 100644)

```python
import itertools

import torch
import triton
from sgl_kernel import ep_moe_silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel

batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
hidden_size_range = [1024, 2048, 4096, 8192]
block_size_range = [128, 256, 512]

configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "hidden_size", "block_size"],
        x_vals=[list(cfg) for cfg in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-silu-and-mul-performance",
        args={},
    )
)
def benchmark(batch_size, hidden_size, block_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    half_hidden_size = hidden_size // 2
    start_expert_id, end_expert_id = 0, 255
    block_size = 512
    quantiles = [0.5, 0.2, 0.8]

    def alloc_tensors():
        # gateup_output holds [gate | up] halves; down_input receives the fused result.
        gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        down_input = torch.empty(batch_size, half_hidden_size, dtype=dtype, device=device)
        reorder_topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return gateup_output, down_input, reorder_topk_ids, scales

    if provider == "cuda":
        gateup, down, ids, scales = alloc_tensors()

        def run_cuda():
            ep_moe_silu_and_mul(
                gateup,
                down,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
    elif provider == "triton":
        gateup, down, ids, scales = alloc_tensors()

        def run_triton():
            silu_and_mul_triton_kernel[(batch_size,)](
                gateup.view(-1),
                down.view(-1),
                hidden_size,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
                block_size,
            )

        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    # do_bench reports milliseconds; convert to microseconds to match ylabel="us".
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
```
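Two things worth noting about this benchmark as committed: `benchmark` reassigns `block_size = 512` before use, so the swept `block_size` values only distinguish config labels and the Triton kernel is always launched with a 512-wide block; and `triton.testing.do_bench` returns milliseconds, hence the `1000 *` conversion to the microseconds promised by the `ylabel`.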
sgl-kernel/csrc/common_extension.cc

```diff
@@ -177,6 +177,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
       "a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
   m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
 
+  m.def(
+      "ep_moe_silu_and_mul(Tensor gateup_output, Tensor down_input, Tensor reorder_topk_ids, Tensor scales, int "
+      "start_expert_id, int end_expert_id) -> ()");
+  m.impl("ep_moe_silu_and_mul", torch::kCUDA, &ep_moe_silu_and_mul);
 
   m.def(
       "ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
       "topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
```
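The schema's `-> ()` return type marks the op as returning nothing: like the neighboring `ep_moe_pre_reorder` and `ep_moe_post_reorder`, it mutates `down_input` in place. Registering it under `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)` is what exposes it to Python as `torch.ops.sgl_kernel.ep_moe_silu_and_mul.default`, which the wrapper added to `sgl-kernel/python/sgl_kernel/moe.py` below calls.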
sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu (new file, mode 100644)

```cuda
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <THC/THCAtomics.cuh>
#include <algorithm>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

using namespace flashinfer;

// SiLU computed in fp32, rounded back to the storage dtype.
template <typename scalar_t>
__device__ inline scalar_t silu_quantize(float x);

template <>
__device__ inline float silu_quantize<float>(float x) {
  float y = x / (1.f + __expf(-x));
  return y;
}

template <>
__device__ inline __half silu_quantize<__half>(float x) {
  float y = x / (1.f + __expf(-x));
  return __float2half_rn(y);
}

template <>
__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) {
  float y = x / (1.f + __expf(-x));
  return __float2bfloat16_rn(y);
}

// One thread block per reordered token; each row of gateup_output is a
// [gate | up] pair of width hidden_size / 2.
template <typename scalar_t>
__global__ void ep_moe_act_and_mul_cuda_kernel(
    const scalar_t* __restrict__ gateup_output,
    scalar_t* __restrict__ down_input,
    const int* __restrict__ reorder_topk_ids,
    const float* __restrict__ scales,
    int start_expert_id,
    int end_expert_id,
    int hidden_size) {
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);  // 16-byte vectors
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int half_hidden_size = hidden_size >> 1;

  // Skip tokens routed to experts this rank does not own.
  const int expert_id = reorder_topk_ids[token_idx];
  if (expert_id < start_expert_id || expert_id > end_expert_id) return;

  const scalar_t* gate_output_ptr = gateup_output + static_cast<int64_t>(token_idx) * hidden_size;
  const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size;
  scalar_t* dst_ptr = down_input + static_cast<int64_t>(token_idx) * half_hidden_size;

  // Per-expert scale; identity when no scales tensor was supplied.
  scalar_t scale_q = static_cast<scalar_t>(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f);

  // Vectorized main loop over 16-byte chunks.
  const uint32_t vec_elements = half_hidden_size / vec_size;
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) {
    vec_t gate_vec, up_vec, out_vec;
    gate_vec.load(gate_output_ptr + idx * vec_size);
    up_vec.load(up_output_ptr + idx * vec_size);

#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float gate_f = static_cast<float>(gate_vec[i]);
      scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
      scalar_t prod = gate_q * up_vec[i] * scale_q;
      out_vec[i] = prod;
    }
    out_vec.store(dst_ptr + idx * vec_size);
  }

  // Scalar tail for elements not covered by the vectorized loop.
  const int64_t scalar_start = static_cast<int64_t>(vec_elements) * vec_size + thread_idx;
#pragma unroll 1
  for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) {
    float gate_f = static_cast<float>(gate_output_ptr[idx]);
    scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
    dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q;
  }
}

void ep_moe_silu_and_mul(
    torch::Tensor gateup_output,
    torch::Tensor down_input,
    torch::Tensor reorder_topk_ids,
    torch::Tensor scales,
    int64_t start_expert_id,
    int64_t end_expert_id) {
  const int total_tokens = gateup_output.size(0);
  const int hidden_size = gateup_output.size(1);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] {
    dim3 grid(total_tokens);
    // Enough threads to cover half_hidden_size in vec_size chunks,
    // at least 256, rounded up to a full warp, capped at 1024.
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    const int half_hidden_size = hidden_size >> 1;
    uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size;
    threads = std::max<uint32_t>(threads, 256);
    threads = ((threads + 31) & ~31U);
    dim3 block(std::min(threads, 1024U));

    ep_moe_act_and_mul_cuda_kernel<scalar_t><<<grid, block>>>(
        static_cast<scalar_t*>(gateup_output.data_ptr()),
        static_cast<scalar_t*>(down_input.data_ptr()),
        reorder_topk_ids.data_ptr<int>(),
        scales.defined() ? scales.data_ptr<float>() : nullptr,
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        hidden_size);
    return true;  // the dispatch macro's lambda returns a success flag
  });
}
```
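As a cross-check of the kernel's semantics, here is a minimal pure-PyTorch reference sketch. It is not part of the commit, and `ref_ep_moe_silu_and_mul` is a hypothetical helper name; it mirrors the kernel's behavior up to floating-point rounding, including skipping rows whose expert id falls outside the local range and applying SiLU in fp32 before casting back:

```python
import torch


def ref_ep_moe_silu_and_mul(
    gateup_output: torch.Tensor,     # [tokens, hidden_size], rows are [gate | up]
    down_input: torch.Tensor,        # [tokens, hidden_size // 2], written in place
    reorder_topk_ids: torch.Tensor,  # [tokens], int32 expert id per reordered token
    scales: torch.Tensor,            # [num_local_experts] fp32, or None
    start_expert_id: int,
    end_expert_id: int,
) -> torch.Tensor:
    half = gateup_output.size(1) // 2
    gate = gateup_output[:, :half].float()
    up = gateup_output[:, half:]
    # SiLU in fp32 (the kernel uses __expf), then round back to the input dtype.
    silu = (gate / (1.0 + torch.exp(-gate))).to(gateup_output.dtype)
    for t in range(gateup_output.size(0)):
        e = int(reorder_topk_ids[t])
        if e < start_expert_id or e > end_expert_id:
            continue  # the kernel returns early and leaves this row untouched
        scale = 1.0 if scales is None else 1.0 / float(scales[e - start_expert_id])
        down_input[t] = silu[t] * up[t] * scale
    return down_input
```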
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -266,6 +266,14 @@ void ep_moe_pre_reorder(
     int64_t topk,
     bool use_per_token_if_dynamic);
 
+void ep_moe_silu_and_mul(
+    torch::Tensor gateup_output,
+    torch::Tensor down_input,
+    torch::Tensor reorder_topk_ids,
+    torch::Tensor scales,
+    int64_t start_expert_id,
+    int64_t end_expert_id);
+
 void ep_moe_post_reorder(
     torch::Tensor down_output,
     torch::Tensor output,
```
sgl-kernel/python/sgl_kernel/__init__.py

```diff
@@ -52,6 +52,7 @@ from sgl_kernel.moe import (
     cutlass_fp4_group_mm,
     ep_moe_post_reorder,
     ep_moe_pre_reorder,
+    ep_moe_silu_and_mul,
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,
     moe_fused_gate,
```
sgl-kernel/python/sgl_kernel/moe.py

```diff
@@ -88,6 +88,24 @@ def ep_moe_pre_reorder(
     )
 
 
+def ep_moe_silu_and_mul(
+    gateup_output,
+    down_input,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+):
+    return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default(
+        gateup_output,
+        down_input,
+        reorder_topk_ids,
+        scales,
+        start_expert_id,
+        end_expert_id,
+    )
+
+
 def ep_moe_post_reorder(
     down_output,
     output,
```
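A minimal usage sketch of the new Python entry point (the shapes and values below are illustrative, chosen to match the test and benchmark in this commit):

```python
import torch
from sgl_kernel import ep_moe_silu_and_mul

tokens, hidden, num_experts = 256, 4096, 16
gateup = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda")
down = torch.empty(tokens, hidden // 2, dtype=torch.bfloat16, device="cuda")
ids = torch.randint(0, num_experts, (tokens,), dtype=torch.int32, device="cuda")
scales = torch.rand(num_experts, dtype=torch.float32, device="cuda") + 0.5

# Writes silu(gate) * up / scales[expert] into `down` in place;
# the underlying op's schema is "-> ()", so nothing useful is returned.
ep_moe_silu_and_mul(gateup, down, ids, scales, 0, num_experts - 1)
```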
sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py (new file, mode 100644)

```python
import itertools

import pytest
import torch
from sgl_kernel import ep_moe_silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel


def create_test_tensors(
    total_tokens: int,
    hidden_size: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
):
    gateup_output = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)
    reorder_topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        (total_tokens,),
        dtype=torch.int32,
        device=device,
    )
    num_experts = end_expert_id - start_expert_id + 1
    scales = torch.rand(num_experts, dtype=torch.float32, device=device) * 0.8 + 0.5
    half_hidden = hidden_size // 2
    down_input = torch.empty(total_tokens, half_hidden, dtype=dtype, device=device)
    return gateup_output, down_input, reorder_topk_ids, scales


def run_cuda_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
):
    ep_moe_silu_and_mul(
        gateup_output,
        down_input,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    return down_input


def run_triton_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    hidden_size: int,
):
    total_tokens = gateup_output.size(0)
    block_size = 512
    silu_and_mul_triton_kernel[(total_tokens,)](
        gateup_output,
        down_input,
        hidden_size,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        block_size,
    )
    return down_input


@pytest.mark.parametrize(
    "total_tokens,hidden_size",
    list(itertools.product([32, 256, 1024], [128, 256, 512])),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_ep_moe_silu_and_mul_vs_triton(
    total_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
):
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    (
        gateup_output,
        _,
        reorder_topk_ids,
        scales,
    ) = create_test_tensors(
        total_tokens,
        hidden_size,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
    )

    down_input_cuda = torch.empty(
        total_tokens, hidden_size // 2, dtype=dtype, device=device
    )
    down_input_triton = torch.empty_like(down_input_cuda)

    cuda_output = run_cuda_kernel(
        gateup_output,
        down_input_cuda,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    triton_output = run_triton_kernel(
        gateup_output,
        down_input_triton,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        hidden_size,
    )

    torch.testing.assert_close(
        cuda_output,
        triton_output,
        rtol=1e-5,
        atol=1e-5,
    )


if __name__ == "__main__":
    pytest.main([__file__])
```
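The test needs a CUDA device and an sgl-kernel build containing this commit; it can be run directly with `pytest sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py`. The shared `rtol=atol=1e-5` tolerance is tight for fp16/bf16 outputs, effectively requiring the CUDA and Triton implementations to round identically.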