Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
43baba64
Unverified
Commit
43baba64
authored
Jun 05, 2025
by
Yuan Luo
Committed by
GitHub
Jun 05, 2025
Browse files
[EP] Add cuda kernel for moe_ep_post_reorder (#6837)
Co-authored-by:
luoyuan.luo
<
luoyuan.luo@antgroup.com
>
parent
0166403c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
377 additions
and
4 deletions
+377
-4
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
+92
-0
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+6
-2
sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
+83
-2
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+10
-0
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+1
-0
sgl-kernel/python/sgl_kernel/moe.py
sgl-kernel/python/sgl_kernel/moe.py
+22
-0
sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
+163
-0
No files found.
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
0 → 100644
View file @
43baba64
import
torch
import
triton
from
sgl_kernel
import
ep_moe_post_reorder
from
sglang.srt.layers.moe.ep_moe.kernels
import
post_reorder_triton_kernel
batch_sizes
=
[
64
,
128
,
256
,
512
,
640
,
768
,
1024
,
2048
,
4096
]
configs
=
[(
bs
,)
for
bs
in
batch_sizes
]
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
list
(
_
)
for
_
in
configs
],
line_arg
=
"provider"
,
line_vals
=
[
"cuda"
,
"triton"
],
line_names
=
[
"CUDA Kernel"
,
"Triton Kernel"
],
styles
=
[(
"green"
,
"-"
),
(
"orange"
,
"-"
)],
ylabel
=
"us"
,
plot_name
=
"ep-moe-post-reorder-performance"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
):
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda"
)
hidden_size
,
topk
,
start_expert_id
,
end_expert_id
,
block_size
=
4096
,
8
,
0
,
255
,
512
def
alloc_tensors
():
down_output
=
torch
.
randn
(
batch_size
*
topk
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
output
=
torch
.
zeros
(
batch_size
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
src2dst
=
torch
.
randint
(
0
,
batch_size
*
topk
,
(
batch_size
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
)
topk_ids
=
torch
.
randint
(
start_expert_id
,
end_expert_id
+
1
,
(
batch_size
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
,
)
topk_weights
=
torch
.
rand
(
batch_size
,
topk
,
dtype
=
dtype
,
device
=
device
)
return
down_output
,
output
,
src2dst
,
topk_ids
,
topk_weights
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"cuda"
:
d_out
,
out
,
s2d
,
tk_ids
,
tk_weights
=
alloc_tensors
()
def
run_cuda
():
ep_moe_post_reorder
(
d_out
,
out
,
s2d
,
tk_ids
,
tk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
run_cuda
,
quantiles
=
quantiles
)
elif
provider
==
"triton"
:
d_out
,
out
,
s2d
,
tk_ids
,
tk_weights
=
alloc_tensors
()
def
run_triton
():
post_reorder_triton_kernel
[(
batch_size
,)](
d_out
.
view
(
-
1
),
out
.
view
(
-
1
),
s2d
.
view
(
-
1
),
tk_ids
.
view
(
-
1
),
tk_weights
.
view
(
-
1
),
start_expert_id
,
end_expert_id
,
topk
,
hidden_size
,
block_size
,
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
run_triton
,
quantiles
=
quantiles
)
else
:
raise
ValueError
(
f
"Unknown provider:
{
provider
}
"
)
return
1000
*
ms
,
1000
*
max_ms
,
1000
*
min_ms
if
__name__
==
"__main__"
:
benchmark
.
run
(
print_data
=
True
)
sgl-kernel/csrc/common_extension.cc
View file @
43baba64
...
...
@@ -174,9 +174,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"(Tensor[])"
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
m
.
def
(
"ep_moe_pre_reorder(Tensor input
_ptr
, Tensor gateup_input
_ptr
, Tensor src2dst
_ptr
, Tensor topk_ids
_ptr
, Tensor "
"a1_scales
_ptr
, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"
);
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"
);
m
.
impl
(
"ep_moe_pre_reorder"
,
torch
::
kCUDA
,
&
ep_moe_pre_reorder
);
m
.
def
(
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()"
);
m
.
impl
(
"ep_moe_post_reorder"
,
torch
::
kCUDA
,
&
ep_moe_post_reorder
);
m
.
def
(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
...
...
sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
View file @
43baba64
...
...
@@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel(
}
}
template
<
typename
scalar_t
>
__global__
void
ep_post_reorder_cuda_kernel
(
const
scalar_t
*
__restrict__
down_output_ptr
,
scalar_t
*
__restrict__
output_ptr
,
const
int
*
__restrict__
src2dst_ptr
,
const
int
*
__restrict__
topk_ids_ptr
,
const
scalar_t
*
__restrict__
topk_weights_ptr
,
int
start_expert_id
,
int
end_expert_id
,
int
topk
,
int
hidden_size
)
{
const
int
token_idx
=
blockIdx
.
x
;
const
int
tid
=
threadIdx
.
x
;
const
int
*
token_src2dst
=
src2dst_ptr
+
token_idx
*
topk
;
const
int
*
token_topk_ids
=
topk_ids_ptr
+
token_idx
*
topk
;
const
scalar_t
*
token_topk_weights
=
topk_weights_ptr
+
token_idx
*
topk
;
scalar_t
*
dst_ptr
=
output_ptr
+
static_cast
<
int64_t
>
(
token_idx
)
*
hidden_size
;
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
scalar_t
);
using
vec_t
=
flashinfer
::
vec_t
<
scalar_t
,
vec_size
>
;
const
int
vec_iters
=
hidden_size
/
vec_size
;
for
(
int
idx
=
tid
;
idx
<
vec_iters
;
idx
+=
blockDim
.
x
)
{
float
acc
[
vec_size
]
=
{
0
};
for
(
int
k
=
0
;
k
<
topk
;
++
k
)
{
const
int
expert_id
=
token_topk_ids
[
k
];
if
(
expert_id
<
start_expert_id
||
expert_id
>
end_expert_id
)
continue
;
const
int
src_row
=
token_src2dst
[
k
];
const
scalar_t
*
src_ptr
=
down_output_ptr
+
static_cast
<
int64_t
>
(
src_row
)
*
hidden_size
;
const
float
weight
=
static_cast
<
float
>
(
token_topk_weights
[
k
]);
vec_t
src_vec
;
src_vec
.
cast_load
(
src_ptr
+
idx
*
vec_size
);
#pragma unroll
for
(
uint32_t
i
=
0
;
i
<
vec_size
;
++
i
)
{
acc
[
i
]
+=
static_cast
<
float
>
(
src_vec
[
i
])
*
weight
;
}
}
vec_t
out_vec
;
#pragma unroll
for
(
uint32_t
i
=
0
;
i
<
vec_size
;
++
i
)
out_vec
[
i
]
=
static_cast
<
scalar_t
>
(
acc
[
i
]);
out_vec
.
cast_store
(
dst_ptr
+
idx
*
vec_size
);
}
}
void
ep_moe_pre_reorder
(
torch
::
Tensor
input
,
torch
::
Tensor
gateup_input
,
...
...
@@ -77,8 +128,8 @@ void ep_moe_pre_reorder(
int64_t
end_expert_id
,
int64_t
topk
,
bool
use_per_token_if_dynamic
)
{
int
total_blocks
=
input
.
size
(
0
);
int
block_size
=
512
;
const
int
total_blocks
=
input
.
size
(
0
);
const
int
block_size
=
512
;
dim3
grid
(
total_blocks
);
dim3
block
(
block_size
);
int
hidden_size
=
input
.
size
(
1
);
...
...
@@ -98,3 +149,33 @@ void ep_moe_pre_reorder(
return
true
;
});
}
void
ep_moe_post_reorder
(
torch
::
Tensor
down_output
,
torch
::
Tensor
output
,
torch
::
Tensor
src2dst
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
int64_t
start_expert_id
,
int64_t
end_expert_id
,
int64_t
topk
)
{
const
int
total_tokens
=
output
.
size
(
0
);
const
int
block_size
=
512
;
dim3
grid
(
total_tokens
);
dim3
block
(
block_size
);
const
int
hidden_size
=
output
.
size
(
1
);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
(
down_output
.
scalar_type
(),
scalar_t
,
[
&
]
{
ep_post_reorder_cuda_kernel
<
scalar_t
><<<
grid
,
block
>>>
(
static_cast
<
scalar_t
*>
(
down_output
.
data_ptr
()),
static_cast
<
scalar_t
*>
(
output
.
data_ptr
()),
src2dst
.
data_ptr
<
int
>
(),
topk_ids
.
data_ptr
<
int
>
(),
static_cast
<
scalar_t
*>
(
topk_weights
.
data_ptr
()),
static_cast
<
int
>
(
start_expert_id
),
static_cast
<
int
>
(
end_expert_id
),
static_cast
<
int
>
(
topk
),
hidden_size
);
return
true
;
});
}
sgl-kernel/include/sgl_kernel_ops.h
View file @
43baba64
...
...
@@ -264,6 +264,16 @@ void ep_moe_pre_reorder(
int64_t
topk
,
bool
use_per_token_if_dynamic
);
void
ep_moe_post_reorder
(
torch
::
Tensor
down_output
,
torch
::
Tensor
output
,
torch
::
Tensor
src2dst
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
int64_t
start_expert_id
,
int64_t
end_expert_id
,
int64_t
topk
);
void
shuffle_rows
(
const
torch
::
Tensor
&
input_tensor
,
const
torch
::
Tensor
&
dst2src_map
,
torch
::
Tensor
&
output_tensor
);
void
cutlass_fp4_group_mm
(
...
...
sgl-kernel/python/sgl_kernel/__init__.py
View file @
43baba64
...
...
@@ -49,6 +49,7 @@ from sgl_kernel.gemm import (
from
sgl_kernel.grammar
import
apply_token_bitmask_inplace_cuda
from
sgl_kernel.moe
import
(
cutlass_fp4_group_mm
,
ep_moe_post_reorder
,
ep_moe_pre_reorder
,
fp8_blockwise_scaled_grouped_mm
,
moe_align_block_size
,
...
...
sgl-kernel/python/sgl_kernel/moe.py
View file @
43baba64
...
...
@@ -88,6 +88,28 @@ def ep_moe_pre_reorder(
)
def
ep_moe_post_reorder
(
down_output
,
output
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
):
return
torch
.
ops
.
sgl_kernel
.
ep_moe_post_reorder
.
default
(
down_output
,
output
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
)
def
fp8_blockwise_scaled_grouped_mm
(
output
,
a_ptrs
,
...
...
sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
0 → 100644
View file @
43baba64
import
itertools
import
pytest
import
torch
from
sgl_kernel
import
ep_moe_post_reorder
from
sglang.srt.layers.moe.ep_moe.kernels
import
post_reorder_triton_kernel
def
create_test_tensors
(
batch_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
start_expert_id
:
int
,
end_expert_id
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
):
down_output
=
torch
.
randn
(
batch_size
*
topk
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
# Ensure src2dst has no duplicate destinations to avoid race conditions
total_tokens
=
batch_size
*
topk
dst_indices
=
torch
.
randperm
(
total_tokens
,
device
=
device
,
dtype
=
torch
.
int32
)
src2dst
=
dst_indices
.
view
(
batch_size
,
topk
)
topk_ids
=
torch
.
randint
(
start_expert_id
,
end_expert_id
+
1
,
(
batch_size
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
,
)
topk_weights
=
torch
.
rand
(
batch_size
,
topk
,
dtype
=
dtype
,
device
=
device
)
return
down_output
,
src2dst
,
topk_ids
,
topk_weights
def
run_cuda_kernel
(
down_output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
src2dst
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
start_expert_id
:
int
,
end_expert_id
:
int
,
topk
:
int
,
):
ep_moe_post_reorder
(
down_output
,
output
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
)
return
output
def
run_triton_kernel
(
down_output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
src2dst
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
start_expert_id
:
int
,
end_expert_id
:
int
,
topk
:
int
,
hidden_size
:
int
,
):
batch_size
=
down_output
.
size
(
0
)
block_size
=
512
post_reorder_triton_kernel
[(
batch_size
,)](
down_output
,
output
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
hidden_size
,
block_size
,
)
return
output
def
assert_close
(
a
,
b
):
a32
,
b32
=
a
.
float
(),
b
.
float
()
if
a
.
dtype
is
torch
.
float16
:
torch
.
testing
.
assert_close
(
a32
,
b32
,
rtol
=
1e-5
,
atol
=
1e-2
)
elif
a
.
dtype
is
torch
.
bfloat16
:
torch
.
testing
.
assert_close
(
a32
,
b32
,
rtol
=
1e-4
,
atol
=
1e-1
)
else
:
torch
.
testing
.
assert_close
(
a32
,
b32
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
"batch_size,hidden_size,topk"
,
list
(
itertools
.
product
([
32
,
64
],
[
128
,
256
,
512
],
[
2
,
4
,
8
])),
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
])
def
test_ep_moe_post_reorder_vs_triton
(
batch_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
):
device
=
torch
.
device
(
"cuda"
)
start_expert_id
=
0
end_expert_id
=
15
(
down_output
,
src2dst
,
topk_ids
,
topk_weights
,
)
=
create_test_tensors
(
batch_size
,
hidden_size
,
topk
,
start_expert_id
,
end_expert_id
,
dtype
,
device
,
)
output_cuda
=
torch
.
empty
(
batch_size
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
output_triton
=
torch
.
empty
(
batch_size
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
cuda_output
=
run_cuda_kernel
(
down_output
,
output_cuda
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
)
triton_output
=
run_triton_kernel
(
down_output
,
output_triton
,
src2dst
,
topk_ids
,
topk_weights
,
start_expert_id
,
end_expert_id
,
topk
,
hidden_size
,
)
assert_close
(
cuda_output
,
triton_output
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment