Commit 6bdd2786 (unverified)
Authored Aug 01, 2025 by Peter Pan; committed by GitHub on Aug 01, 2025

[Kimi K2] dsv3_router_gemm supports NUM_EXPERTS == 384 (#8013)

Parent: 46e9d1c7
Changes: 5 changed files with 188 additions and 30 deletions (+188, -30)
  sgl-kernel/benchmark/bench_dsv3_router_gemm.py      +36  -12
  sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu    +50  -0
  sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu       +50  -16
  sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu   +50  -0
  sgl-kernel/tests/test_dsv3_router_gemm.py            +2   -2
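At the caller level, this commit means sgl_kernel.dsv3_router_gemm now accepts a router weight with 384 rows (Kimi K2's expert count) in addition to the previous 256 (DeepSeek-V3). A minimal usage sketch, assuming the Python wrapper allocates and returns the output tensor, as the benchmark's call pattern below suggests:

    import torch
    from sgl_kernel import dsv3_router_gemm

    num_tokens, hidden_dim, num_experts = 4, 7168, 384  # Kimi K2 router shape
    hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device="cuda")
    router_weight = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")

    # Router logits for each token; out_dtype may be torch.bfloat16 or torch.float32.
    # Assumes the wrapper returns the result rather than filling a caller-provided buffer.
    logits = dsv3_router_gemm(hidden_states, router_weight, out_dtype=torch.float32)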
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
@@ -13,9 +13,14 @@ from sgl_kernel import dsv3_router_gemm
         x_vals=[i + 1 for i in range(16)],
         x_log=False,
         line_arg="impl",
-        line_vals=["torch", "sgl-kernel"],
-        line_names=["torch", "dsv3_router_gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
+        line_names=[
+            "torch-256",
+            "dsv3_router_gemm-256",
+            "torch-384",
+            "dsv3_router_gemm-384",
+        ],
+        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
         ylabel="TFLOPs",
         plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
         args={},
@@ -23,19 +28,26 @@ from sgl_kernel import dsv3_router_gemm
 )
 def benchmark_bf16_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
-    M, K, N = num_tokens, 7168, 256
+    M, K = num_tokens, 7168
+    if impl == "torch-256" or impl == "sgl-kernel-256":
+        N = 256
+    elif impl == "torch-384" or impl == "sgl-kernel-384":
+        N = 384
+    else:
+        raise ValueError(f"Unknown impl: {impl}")
     mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
     mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
     quantiles = [0.5, 0.2, 0.8]
-    if impl == "torch":
+    if impl == "torch-256" or impl == "torch-384":

         def runner():
             F.linear(mat_a, mat_b)

-    elif impl == "sgl-kernel":
+    elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":

         def runner():
             dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
@@ -55,9 +67,14 @@ def benchmark_bf16_output(num_tokens, impl):
         x_vals=[i + 1 for i in range(16)],
         x_log=False,
         line_arg="impl",
-        line_vals=["torch", "sgl-kernel"],
-        line_names=["torch", "dsv3_router_gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
+        line_names=[
+            "torch-256",
+            "dsv3_router_gemm-256",
+            "torch-384",
+            "dsv3_router_gemm-384",
+        ],
+        styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
         ylabel="TFLOPs",
         plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
         args={},
@@ -65,19 +82,26 @@ def benchmark_bf16_output(num_tokens, impl):
 )
 def benchmark_float_output(num_tokens, impl):
     # M: num_tokens, K: hidden_dim, N: num_experts
-    M, K, N = num_tokens, 7168, 256
+    M, K = num_tokens, 7168
+    if impl == "torch-256" or impl == "sgl-kernel-256":
+        N = 256
+    elif impl == "torch-384" or impl == "sgl-kernel-384":
+        N = 384
+    else:
+        raise ValueError(f"Unknown impl: {impl}")
     mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
     mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
     quantiles = [0.5, 0.2, 0.8]
-    if impl == "torch":
+    if impl == "torch-256" or impl == "torch-384":

         def runner():
             F.linear(mat_a, mat_b).to(torch.float32)

-    elif impl == "sgl-kernel":
+    elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":

         def runner():
             dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
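For a quick spot check of the new 384-expert path outside the benchmark harness above, the two implementations can be timed directly with CUDA events. This is a simplified, hypothetical stand-in for the repository's benchmark, not part of the commit:

    import torch
    import torch.nn.functional as F
    from sgl_kernel import dsv3_router_gemm

    def time_ms(fn, iters=100):
        # Average wall-clock time of a CUDA callable in milliseconds.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(10):  # warmup
            fn()
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    M, K, N = 8, 7168, 384  # num_tokens, hidden_dim, num_experts (Kimi K2)
    mat_a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

    print("torch     :", time_ms(lambda: F.linear(mat_a, mat_b)), "ms")
    print("sgl-kernel:", time_ms(lambda: dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)), "ms")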
sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu
@@ -185,6 +185,7 @@ void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const*
         mat_b);
 }

+// Template instantiations for DEFAULT_NUM_EXPERTS experts
 template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
     __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
@@ -232,3 +233,52 @@ template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
 template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
     __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+// Template instantiations for KIMI_K2_NUM_EXPERTS experts
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 384, 7168>(
+    __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu
@@ -25,6 +25,10 @@
#include "cuda_runtime.h"
#include "utils.h"
static
constexpr
int
DEFAULT_NUM_EXPERTS
=
256
;
static
constexpr
int
KIMI_K2_NUM_EXPERTS
=
384
;
static
constexpr
int
DEFAULT_HIDDEN_DIM
=
7168
;
template
<
typename
T
,
int
kNumTokens
,
int
kNumExperts
,
int
kHiddenDim
>
void
invokeRouterGemmFloatOutput
(
float
*
output
,
T
const
*
mat_a
,
T
const
*
mat_b
,
cudaStream_t
stream
);
...
...
@@ -91,12 +95,24 @@ void dsv3_router_gemm(
   TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);

   const int num_tokens = mat_a.size(0);
-  constexpr int num_experts = 256;
-  constexpr int hidden_dim = 7168;
+  const int num_experts = mat_b.size(0);
+  const int hidden_dim = mat_a.size(1);

   TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
-  TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
-  TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
+  TORCH_CHECK(
+      hidden_dim == DEFAULT_HIDDEN_DIM,
+      "Expected hidden_dim=", DEFAULT_HIDDEN_DIM, ", but got hidden_dim=", hidden_dim);
+  TORCH_CHECK(
+      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
+      "Expected num_experts=", DEFAULT_NUM_EXPERTS, " or num_experts=", KIMI_K2_NUM_EXPERTS,
+      ", but got num_experts=", num_experts);
   TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm");
   TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
@@ -110,18 +126,36 @@ void dsv3_router_gemm(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   if (output.dtype() == torch::kFloat32) {
-    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
-        num_tokens,
-        reinterpret_cast<float*>(output.mutable_data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-        stream);
+    if (num_experts == DEFAULT_NUM_EXPERTS) {
+      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
+          num_tokens,
+          reinterpret_cast<float*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
+      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
+          num_tokens,
+          reinterpret_cast<float*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    }
   } else if (output.dtype() == torch::kBFloat16) {
-    LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
-        num_tokens,
-        reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
-        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
-        stream);
+    if (num_experts == DEFAULT_NUM_EXPERTS) {
+      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
+          num_tokens,
+          reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    } else if (num_experts == KIMI_K2_NUM_EXPERTS) {
+      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
+          num_tokens,
+          reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
+          reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
+          stream);
+    }
   }
 }
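The entry point above only dispatches to shapes that have explicit template instantiations. A Python-side pre-flight check mirroring the TORCH_CHECKs visible in this hunk could fail fast before calling into the kernel; the helper name below is hypothetical and not part of the commit:

    import torch

    def validate_router_gemm_inputs(mat_a: torch.Tensor, mat_b: torch.Tensor) -> None:
        # Mirrors the constraints enforced in dsv3_router_gemm_entry.cu.
        num_tokens, hidden_dim = mat_a.shape
        num_experts = mat_b.size(0)
        assert mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, "inputs must be bf16"
        assert mat_b.size(1) == hidden_dim, "mat_a and mat_b must have the same hidden_dim"
        assert hidden_dim == 7168, "only hidden_dim=7168 is instantiated"
        assert num_experts in (256, 384), "only 256 (DeepSeek-V3) or 384 (Kimi K2) experts are instantiated"
        assert 1 <= num_tokens <= 16, "router gemm kernels are unrolled for 1..16 tokens"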
sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu
@@ -184,6 +184,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
         mat_b);
 }

+// Template instantiations for DEFAULT_NUM_EXPERTS experts
 template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
     float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
@@ -231,3 +232,52 @@ template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
 template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
     float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+
+// Template instantiations for KIMI_K2_NUM_EXPERTS experts
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
+template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>(
+    float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
sgl-kernel/tests/test_dsv3_router_gemm.py
@@ -5,8 +5,8 @@ from sgl_kernel import dsv3_router_gemm
 @pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
-def test_dsv3_router_gemm(num_tokens):
-    num_experts = 256
+@pytest.mark.parametrize("num_experts", [256, 384])
+def test_dsv3_router_gemm(num_tokens, num_experts):
     hidden_dim = 7168
     mat_a = torch.randn(
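The test body is truncated in this view. A plausible reference check, consistent with the benchmark's use of F.linear as the baseline (the test name and tolerances below are illustrative assumptions, not the repository's), would be:

    import pytest
    import torch
    import torch.nn.functional as F
    from sgl_kernel import dsv3_router_gemm

    @pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
    @pytest.mark.parametrize("num_experts", [256, 384])
    def test_dsv3_router_gemm_reference(num_tokens, num_experts):
        hidden_dim = 7168
        mat_a = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16, device="cuda")
        mat_b = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")
        # Reference router logits computed with PyTorch, accumulated in fp32.
        ref = F.linear(mat_a, mat_b).float()
        out = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
        torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)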