Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
45dcfc2e
"graphbolt/src/vscode:/vscode.git/clone" did not exist on "314cedc1b1c3c5ffd2ee0a980010b62faf120f1f"
Unverified
Commit
45dcfc2e
authored
Mar 29, 2025
by
Qingquan Song
Committed by
GitHub
Mar 29, 2025
Browse files
Add deepseek style fused moe group gate selection kernel (#4530)
parent
ddf8981d
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
616 additions
and
1 deletion
+616
-1
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+1
-0
sgl-kernel/benchmark/bench_moe_fused_gate.py
sgl-kernel/benchmark/bench_moe_fused_gate.py
+74
-0
sgl-kernel/csrc/moe/moe_fused_gate.cu
sgl-kernel/csrc/moe/moe_fused_gate.cu
+447
-0
sgl-kernel/csrc/torch_extension.cc
sgl-kernel/csrc/torch_extension.cc
+5
-0
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+3
-0
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+1
-1
sgl-kernel/python/sgl_kernel/moe.py
sgl-kernel/python/sgl_kernel/moe.py
+12
-0
sgl-kernel/setup.py
sgl-kernel/setup.py
+1
-0
sgl-kernel/tests/test_moe_fused_gate.py
sgl-kernel/tests/test_moe_fused_gate.py
+72
-0
No files found.
sgl-kernel/CMakeLists.txt
View file @
45dcfc2e
...
...
@@ -151,6 +151,7 @@ set(SOURCES
"csrc/gemm/per_token_group_quant_8bit.cu"
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/speculative_sampling.cu"
...
...
sgl-kernel/benchmark/bench_moe_fused_gate.py
0 → 100644
View file @
45dcfc2e
import
itertools
import
math
import
torch
import
triton
import
triton.language
as
tl
from
sgl_kernel
import
moe_fused_gate
from
sglang.srt.layers.moe.topk
import
biased_grouped_topk
def
biased_grouped_topk_org
(
scores
,
bias
,
num_expert_group
,
topk_group
,
topk
):
return
biased_grouped_topk
(
scores
,
scores
,
bias
,
topk
=
topk
,
renormalize
=
True
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
)
def
biased_grouped_topk_org_kernel
(
scores
,
bias
,
num_expert_group
,
topk_group
,
topk
):
return
moe_fused_gate
(
scores
,
bias
,
num_expert_group
,
topk_group
,
topk
)
seq_length_range
=
[
5000
,
10000
,
15000
,
20000
,
25000
,
30000
,
35000
,
40000
]
configs
=
[(
sq
,)
for
sq
in
seq_length_range
]
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"seq_length"
],
x_vals
=
[
list
(
_
)
for
_
in
configs
],
line_arg
=
"provider"
,
line_vals
=
[
"original"
,
"kernel"
],
line_names
=
[
"Original"
,
"SGL Kernel"
],
styles
=
[(
"blue"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"us"
,
plot_name
=
"moe-fused-gate-performance"
,
args
=
{},
)
)
def
benchmark
(
seq_length
,
provider
):
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda"
)
num_experts
,
num_expert_group
,
topk_group
,
topk
=
256
,
8
,
4
,
8
scores
=
torch
.
randn
((
seq_length
,
num_experts
),
device
=
device
,
dtype
=
dtype
)
bias
=
torch
.
rand
(
num_experts
,
device
=
device
,
dtype
=
dtype
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"original"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
biased_grouped_topk_org
(
scores
.
clone
(),
bias
.
clone
(),
num_expert_group
,
topk_group
,
topk
),
quantiles
=
quantiles
,
)
elif
provider
==
"kernel"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
biased_grouped_topk_org_kernel
(
scores
.
clone
(),
bias
.
clone
(),
num_expert_group
,
topk_group
,
topk
),
quantiles
=
quantiles
,
)
return
1000
*
ms
,
1000
*
max_ms
,
1000
*
min_ms
if
__name__
==
"__main__"
:
benchmark
.
run
(
print_data
=
True
)
sgl-kernel/csrc/moe/moe_fused_gate.cu
0 → 100644
View file @
45dcfc2e
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
template
<
typename
T
,
int
N
>
using
AlignedArray
=
cutlass
::
AlignedArray
<
T
,
N
>
;
using
bfloat16_t
=
cutlass
::
bfloat16_t
;
using
float16_t
=
cutlass
::
half_t
;
using
float32_t
=
float
;
// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" matches these operands: built-in operator
// "arithmetic > arithmetic" function "operator>(const __half &, const __half &)"
template
<
typename
T
>
__device__
inline
bool
cmp_gt
(
const
T
&
a
,
const
T
&
b
)
{
if
constexpr
(
std
::
is_same
<
T
,
at
::
Half
>::
value
)
{
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
}
else
{
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
return
a
>
b
;
}
}
template
<
typename
T
>
__device__
inline
bool
cmp_eq
(
const
T
&
a
,
const
T
&
b
)
{
if
constexpr
(
std
::
is_same
<
T
,
at
::
Half
>::
value
)
{
return
static_cast
<
float
>
(
a
)
==
static_cast
<
float
>
(
b
);
}
else
{
return
a
==
b
;
}
}
// Fixed constants common to both dynamic and static template versions:
static
constexpr
int
WARP_SIZE
=
32
;
static
constexpr
int
WARPS_PER_CTA
=
6
;
static
constexpr
int
MAX_VPT
=
32
;
// maximum VPT we support, > params.VPT = num_expert / num_expert_group
// Create an alias for Array using AlignedArray
template
<
typename
T
,
int
N
>
using
Array
=
AlignedArray
<
T
,
N
>
;
// QQ: NOTE expression must have a constant value, this has to be > params.VPT
template
<
typename
T
>
using
AccessType
=
AlignedArray
<
T
,
MAX_VPT
>
;
template
<
typename
T
,
typename
Params
>
__device__
void
moe_fused_gate_impl
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
topk_group
,
int64_t
topk
,
Params
params
)
{
int
tidx
=
threadIdx
.
x
;
int64_t
thread_row
=
blockIdx
.
x
*
params
.
ROWS_PER_CTA
+
threadIdx
.
y
*
params
.
ROWS_PER_WARP
+
tidx
/
params
.
THREADS_PER_ROW
;
if
(
thread_row
>=
num_rows
)
{
return
;
}
// Cast pointers to type T:
auto
*
input_ptr
=
reinterpret_cast
<
T
*>
(
input
);
auto
*
bias_ptr
=
reinterpret_cast
<
T
*>
(
bias
);
auto
*
thread_row_ptr
=
input_ptr
+
thread_row
*
params
.
NUM_EXPERTS
;
int
thread_group_idx
=
tidx
%
params
.
THREADS_PER_ROW
;
int
first_elt_read_by_thread
=
thread_group_idx
*
params
.
VPT
;
// Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
// AccessType.
T
*
thread_read_ptr
=
thread_row_ptr
+
first_elt_read_by_thread
;
Array
<
T
,
MAX_VPT
>
row_chunk
;
AccessType
<
T
>
const
*
vec_thread_read_ptr
=
reinterpret_cast
<
AccessType
<
T
>
const
*>
(
thread_read_ptr
);
T
*
bias_thread_read_ptr
=
bias_ptr
+
first_elt_read_by_thread
;
Array
<
T
,
MAX_VPT
>
bias_chunk
;
AccessType
<
T
>
const
*
vec_bias_thread_read_ptr
=
reinterpret_cast
<
AccessType
<
T
>
const
*>
(
bias_thread_read_ptr
);
// QQ NOTE: doing the follow will be slower than loop assign and more importantly
// have misaligned address issue when params.VPT < 8 and mismatch with MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
vec_thread_read_ptr
[
0
][
ii
];
bias_chunk
[
ii
]
=
vec_bias_thread_read_ptr
[
0
][
ii
];
}
__syncthreads
();
////////////////////// Sigmoid //////////////////////
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
row_chunk
[
ii
]
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
+
expf
(
-
float
(
row_chunk
[
ii
]))));
}
__syncthreads
();
////////////////////// Add Bias //////////////////////
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
bias_chunk
[
ii
]
=
row_chunk
[
ii
]
+
bias_chunk
[
ii
];
}
////////////////////// Exclude Groups //////////////////////
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
params
.
THREADS_PER_ROW
-
topk_group
;
++
k_idx
)
{
// QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
int
expert
=
first_elt_read_by_thread
;
// local argmax
T
max_val
=
static_cast
<
T
>
(
-
FLT_MAX
);
T
max_val_second
=
static_cast
<
T
>
(
-
FLT_MAX
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
T
val
=
bias_chunk
[
ii
];
if
(
cmp_gt
(
val
,
max_val
))
{
max_val_second
=
max_val
;
max_val
=
val
;
}
else
if
(
cmp_gt
(
val
,
max_val_second
))
{
max_val_second
=
val
;
}
}
// QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert group and sum them as the group weight
// to select expert groups
T
max_sum
=
max_val
+
max_val_second
;
// argmin reduce
#pragma unroll
for
(
int
mask
=
params
.
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
T
other_max_sum
=
static_cast
<
T
>
(
__shfl_xor_sync
(
0xFFFFFFFF
,
static_cast
<
float
>
(
max_sum
),
mask
,
params
.
THREADS_PER_ROW
));
int
other_expert
=
__shfl_xor_sync
(
0xFFFFFFFF
,
expert
,
mask
,
params
.
THREADS_PER_ROW
);
// higher indices win
if
(
cmp_gt
(
max_sum
,
other_max_sum
)
||
(
cmp_eq
(
other_max_sum
,
max_sum
)
&&
other_expert
>
expert
))
{
max_sum
=
other_max_sum
;
expert
=
other_expert
;
}
}
// clear the max value in the thread
if
(
k_idx
<
params
.
THREADS_PER_ROW
-
topk_group
)
{
int
const
thread_to_clear_in_group
=
expert
/
params
.
VPT
;
if
(
thread_group_idx
==
thread_to_clear_in_group
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
params
.
VPT
;
++
ii
)
{
bias_chunk
[
ii
]
=
static_cast
<
T
>
(
FLT_MAX
);
}
}
}
}
__syncthreads
();
////////////////////// Topk //////////////////////
float
output_sum
=
0.0
f
;
for
(
int
k_idx
=
0
;
k_idx
<
topk
;
++
k_idx
)
{
// local argmax
T
max_val
=
bias_chunk
[
0
];
int
expert
=
first_elt_read_by_thread
;
if
(
!
cmp_eq
(
max_val
,
static_cast
<
T
>
(
FLT_MAX
)))
{
#pragma unroll
for
(
int
ii
=
1
;
ii
<
params
.
VPT
;
++
ii
)
{
T
val
=
bias_chunk
[
ii
];
if
(
cmp_gt
(
val
,
max_val
))
{
max_val
=
val
;
expert
=
first_elt_read_by_thread
+
ii
;
}
}
}
else
{
max_val
=
static_cast
<
T
>
(
-
FLT_MAX
);
}
// argmax reduce
#pragma unroll
for
(
int
mask
=
params
.
THREADS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
T
other_max
=
static_cast
<
T
>
(
__shfl_xor_sync
(
0xFFFFFFFF
,
static_cast
<
float
>
(
max_val
),
mask
,
params
.
THREADS_PER_ROW
));
int
other_expert
=
__shfl_xor_sync
(
0xFFFFFFFF
,
expert
,
mask
,
params
.
THREADS_PER_ROW
);
// lower indices to win
if
(
cmp_gt
(
other_max
,
max_val
)
||
(
cmp_eq
(
other_max
,
max_val
)
&&
other_expert
<
expert
))
{
max_val
=
other_max
;
expert
=
other_expert
;
}
}
if
(
k_idx
<
topk
)
{
int
thread_to_clear_in_group
=
expert
/
params
.
VPT
;
int64_t
idx
=
topk
*
thread_row
+
k_idx
;
if
(
thread_group_idx
==
thread_to_clear_in_group
)
{
int
expert_to_clear_in_thread
=
expert
%
params
.
VPT
;
// clear the max value in the thread
bias_chunk
[
expert_to_clear_in_thread
]
=
static_cast
<
T
>
(
-
FLT_MAX
);
// store output
output_ptr
[
idx
]
=
static_cast
<
float
>
(
row_chunk
[
expert_to_clear_in_thread
]);
indices_ptr
[
idx
]
=
static_cast
<
int32_t
>
(
expert
);
}
// accumulate sum
if
(
thread_group_idx
==
0
)
{
output_sum
+=
output_ptr
[
idx
];
}
}
__syncthreads
();
}
////////////////////// Rescale Output //////////////////////
if
(
thread_group_idx
==
0
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
topk
;
++
ii
)
{
int64_t
const
idx
=
topk
*
thread_row
+
ii
;
output_ptr
[
idx
]
=
static_cast
<
float
>
(
static_cast
<
T
>
(
output_ptr
[
idx
])
/
static_cast
<
T
>
(
output_sum
));
}
}
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
template
<
int
VPT_
,
int
NUM_EXPERTS_
,
int
THREADS_PER_ROW_
,
int
ROWS_PER_WARP_
,
int
ROWS_PER_CTA_
,
int
WARPS_PER_CTA_
>
struct
KernelParams
{
static
constexpr
int
VPT
=
VPT_
;
static
constexpr
int
NUM_EXPERTS
=
NUM_EXPERTS_
;
static
constexpr
int
THREADS_PER_ROW
=
THREADS_PER_ROW_
;
static
constexpr
int
ROWS_PER_WARP
=
ROWS_PER_WARP_
;
static
constexpr
int
ROWS_PER_CTA
=
ROWS_PER_CTA_
;
static
constexpr
int
WARPS_PER_CTA
=
WARPS_PER_CTA_
;
};
template
<
typename
T
,
int
VPT
,
int
NUM_EXPERTS
,
int
THREADS_PER_ROW
,
int
ROWS_PER_WARP
,
int
ROWS_PER_CTA
,
int
WARPS_PER_CTA
>
__global__
void
moe_fused_gate_kernel
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
topk_group
,
int64_t
topk
)
{
KernelParams
<
VPT
,
NUM_EXPERTS
,
THREADS_PER_ROW
,
ROWS_PER_WARP
,
ROWS_PER_CTA
,
WARPS_PER_CTA
>
params
;
moe_fused_gate_impl
<
T
>
(
input
,
bias
,
output_ptr
,
indices_ptr
,
num_rows
,
topk_group
,
topk
,
params
);
}
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */
\
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk); \
dispatched = true; \
} while (0)
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
struct
KernelParamsDynamic
{
int
VPT
;
int
NUM_EXPERTS
;
int
THREADS_PER_ROW
;
int
ROWS_PER_WARP
;
int
ROWS_PER_CTA
;
int
WARPS_PER_CTA
;
};
template
<
typename
T
>
__global__
void
moe_fused_gate_kernel_dynamic
(
void
*
input
,
void
*
bias
,
float
*
output_ptr
,
int32_t
*
indices_ptr
,
int64_t
num_rows
,
int64_t
num_experts
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
)
{
KernelParamsDynamic
params
;
params
.
NUM_EXPERTS
=
num_experts
;
// e.g, for deepseek v3, this is 256
params
.
VPT
=
num_experts
/
num_expert_group
;
// e.g., for deepseek v3, this is 256 / 8 = 32
params
.
THREADS_PER_ROW
=
num_expert_group
;
// fixed as num_expert_group, e.g., for deepseek v3, this is 8
params
.
WARPS_PER_CTA
=
WARPS_PER_CTA
;
// fixed as 6
params
.
ROWS_PER_WARP
=
std
::
max
<
int64_t
>
(
1
,
WARP_SIZE
/
num_expert_group
);
// WARP_SIZE is fixed as 32
params
.
ROWS_PER_CTA
=
params
.
WARPS_PER_CTA
*
params
.
ROWS_PER_WARP
;
moe_fused_gate_impl
<
T
>
(
input
,
bias
,
output_ptr
,
indices_ptr
,
num_rows
,
topk_group
,
topk
,
params
);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std
::
vector
<
at
::
Tensor
>
moe_fused_gate
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
)
{
int64_t
num_rows
=
input
.
size
(
0
);
int32_t
num_experts
=
input
.
size
(
1
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
auto
output
=
torch
::
empty
({
num_rows
,
topk
},
options
);
auto
indices
=
torch
::
empty
({
num_rows
,
topk
},
options
.
dtype
(
torch
::
kInt32
));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t
rows_per_warp
=
std
::
max
<
int64_t
>
(
1
,
WARP_SIZE
/
num_expert_group
);
int64_t
num_warps
=
(
num_rows
+
rows_per_warp
-
1
)
/
rows_per_warp
;
int64_t
num_blocks
=
(
num_warps
+
WARPS_PER_CTA
-
1
)
/
WARPS_PER_CTA
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
block_dim
(
WARP_SIZE
,
WARPS_PER_CTA
);
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK
((
num_experts
&
(
num_experts
-
1
))
==
0
,
"num_experts must be a power of 2, but got "
,
num_experts
);
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
TORCH_CHECK
(
num_experts
%
num_expert_group
==
0
,
"num_experts must be divisible by num_expert_group, but got "
,
num_experts
,
" / "
,
num_expert_group
);
int
computed_vpt
=
num_experts
/
num_expert_group
;
// Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
// threads we can process.
TORCH_CHECK
(
computed_vpt
<=
MAX_VPT
,
"Per group experts: num_experts / num_expert_group = ("
,
computed_vpt
,
") exceeds the maximum supported ("
,
MAX_VPT
,
")"
);
// Dispatch to templated kernel for known compile-time configurations.
// We currently only support for:
// Case 1: 256 experts, with 8 or 16 groups.
// Case 2: 128 experts, with 4 or 8 groups.
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
bool
dispatched
=
false
;
switch
(
num_experts
)
{
case
256
:
if
(
num_expert_group
==
8
)
// This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
bfloat16_t
,
256
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
float16_t
,
256
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float32_t
,
256
,
8
);
}
else
if
(
num_expert_group
==
16
)
// Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
bfloat16_t
,
256
,
16
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
float16_t
,
256
,
16
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float32_t
,
256
,
16
);
}
break
;
case
128
:
if
(
num_expert_group
==
4
)
// VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
bfloat16_t
,
128
,
4
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
float16_t
,
128
,
4
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float32_t
,
128
,
4
);
}
else
if
(
num_expert_group
==
8
)
// VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
LAUNCH_MOE_GATE_CONFIG
(
bfloat16_t
,
128
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
LAUNCH_MOE_GATE_CONFIG
(
float16_t
,
128
,
8
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
LAUNCH_MOE_GATE_CONFIG
(
float32_t
,
128
,
8
);
}
break
;
default:
break
;
}
if
(
!
dispatched
)
{
// Fallback to the dynamic kernel if none of the supported combinations match.
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
if
(
input
.
scalar_type
()
==
at
::
kBFloat16
)
{
moe_fused_gate_kernel_dynamic
<
bfloat16_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
moe_fused_gate_kernel_dynamic
<
float16_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
moe_fused_gate_kernel_dynamic
<
float32_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
bias
.
data_ptr
(),
output
.
data_ptr
<
float
>
(),
indices
.
data_ptr
<
int32_t
>
(),
num_rows
,
num_experts
,
num_expert_group
,
topk_group
,
topk
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type for moe_fused_gate"
);
}
}
return
{
output
,
indices
};
}
sgl-kernel/csrc/torch_extension.cc
View file @
45dcfc2e
...
...
@@ -138,6 +138,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
def
(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
"(Tensor[])"
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
/*
* From csrc/speculative
*/
...
...
sgl-kernel/include/sgl_kernel_ops.h
View file @
45dcfc2e
...
...
@@ -199,6 +199,9 @@ void topk_softmax(
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
std
::
vector
<
at
::
Tensor
>
moe_fused_gate
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
,
int64_t
num_expert_group
,
int64_t
topk_group
,
int64_t
topk
);
/*
* From csrc/speculative
*/
...
...
sgl-kernel/python/sgl_kernel/__init__.py
View file @
45dcfc2e
...
...
@@ -36,7 +36,7 @@ from sgl_kernel.gemm import (
sgl_per_token_group_quant_int8
,
sgl_per_token_quant_fp8
,
)
from
sgl_kernel.moe
import
moe_align_block_size
,
topk_softmax
from
sgl_kernel.moe
import
moe_align_block_size
,
moe_fused_gate
,
topk_softmax
from
sgl_kernel.sampling
import
(
min_p_sampling_from_probs
,
top_k_renorm_prob
,
...
...
sgl-kernel/python/sgl_kernel/moe.py
View file @
45dcfc2e
...
...
@@ -32,3 +32,15 @@ def topk_softmax(
torch
.
ops
.
sgl_kernel
.
topk_softmax
.
default
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output
)
def
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
return
torch
.
ops
.
sgl_kernel
.
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
)
sgl-kernel/setup.py
View file @
45dcfc2e
...
...
@@ -161,6 +161,7 @@ sources = [
"csrc/gemm/per_token_quant_fp8.cu"
,
"csrc/gemm/per_tensor_quant_fp8.cu"
,
"csrc/moe/moe_align_kernel.cu"
,
"csrc/moe/moe_fused_gate.cu"
,
"csrc/moe/moe_topk_softmax_kernels.cu"
,
"csrc/speculative/eagle_utils.cu"
,
"csrc/speculative/speculative_sampling.cu"
,
...
...
sgl-kernel/tests/test_moe_fused_gate.py
0 → 100644
View file @
45dcfc2e
import
pytest
import
torch
from
sgl_kernel
import
moe_fused_gate
from
sglang.srt.layers.moe.topk
import
biased_grouped_topk
@
pytest
.
mark
.
parametrize
(
"seq_length"
,
list
(
range
(
1
,
10
))
+
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
32768
,
65536
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
float32
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"params"
,
[
(
128
,
4
,
2
,
4
),
(
256
,
8
,
4
,
8
),
# deepseek v3
(
512
,
16
,
8
,
16
),
],
)
def
test_moe_fused_gate_combined
(
seq_length
,
dtype
,
params
):
num_experts
,
num_expert_group
,
topk_group
,
topk
=
params
torch
.
manual_seed
(
seq_length
)
tensor
=
torch
.
rand
((
seq_length
,
num_experts
)).
to
(
dtype
).
cuda
()
scores
=
tensor
.
clone
()
bias
=
torch
.
rand
(
num_experts
).
to
(
dtype
).
cuda
()
output
,
indices
=
moe_fused_gate
(
tensor
,
bias
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
topk
=
topk
,
)
ref_output
,
ref_indices
=
biased_grouped_topk
(
scores
,
scores
,
bias
,
topk
=
topk
,
renormalize
=
True
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
compiled
=
False
,
)
idx_check
=
torch
.
allclose
(
ref_indices
.
sort
()[
0
].
to
(
torch
.
int32
),
indices
.
sort
()[
0
].
to
(
torch
.
int32
),
rtol
=
1e-04
,
atol
=
1e-05
,
)
output_check
=
torch
.
allclose
(
ref_output
.
sort
()[
0
].
to
(
torch
.
float32
),
output
.
sort
()[
0
].
to
(
torch
.
float32
),
rtol
=
1e-04
,
atol
=
1e-05
,
)
assert
idx_check
,
(
f
"Indices mismatch at seq_length
{
seq_length
}
, dtype
{
dtype
}
, "
f
"params
{
params
}
"
)
assert
output_check
,
(
f
"Output mismatch at seq_length
{
seq_length
}
, dtype
{
dtype
}
, "
f
"params
{
params
}
"
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment