sglang · Commits · 6c01844f

[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)

Unverified commit 6c01844f, authored Oct 16, 2025 by Qi Yuhang, committed by GitHub Oct 15, 2025.
Parent: f226d3da

Showing 7 changed files with 22 additions and 8 deletions (+22 −8).
sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py           +2 −0
sgl-kernel/csrc/common_extension.cc                                   +2 −1
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu             +4 −1
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh   +8 −4
sgl-kernel/include/sgl_kernel_ops.h                                   +2 −1
sgl-kernel/python/sgl_kernel/expert_specialization.py                 +2 −0
sgl-kernel/tests/test_es_fp8_blockwise_moe.py                         +2 −1
sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py

@@ -133,6 +133,7 @@ def bench_es(
     d_strides = torch.full((num_groups,), c_out.stride(0), device=device, dtype=torch.int64)
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)

     def run_cutlass():
         es_fp8_blockwise_scaled_grouped_mm(

@@ -146,6 +147,7 @@ def bench_es(
             d_strides,
             problem_sizes,
             expert_offsets[:-1],
+            workspace,
         )

     run_cutlass()
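For reference, a minimal sketch of the new calling convention from Python. The `sgl_kernel` import path and the wrapper name `run_grouped_mm` are assumptions for illustration; every tensor argument is expected to be prepared exactly as in the benchmark above, and the only new requirement is the trailing uint8 workspace buffer.

    import torch

    # Assumed import path; the Python wrapper itself is updated later in this commit.
    from sgl_kernel import es_fp8_blockwise_scaled_grouped_mm


    def run_grouped_mm(output, a, b, scales_a, scales_b,
                       a_strides, b_strides, d_strides,
                       problem_sizes, expert_offsets, workspace):
        # Same positional order as the registered schema; the workspace is now the
        # caller's responsibility (a plain uint8 CUDA tensor, 1 GiB in the benchmark).
        es_fp8_blockwise_scaled_grouped_mm(
            output, a, b, scales_a, scales_b,
            a_strides, b_strides, d_strides,
            problem_sizes, expert_offsets, workspace,
        )

Allocating the buffer outside the timed region, as the benchmark now does, keeps scratch allocation out of the measured kernel path.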
sgl-kernel/csrc/common_extension.cc

@@ -537,7 +537,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   */
   m.def(
       "es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
-      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()");
+      "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
+      "()");
   m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
 }
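A quick way to confirm the updated schema from Python, assuming the sgl_kernel extension is importable; `_schema` is an internal PyTorch attribute, so this is a debugging aid rather than a supported API.

    import torch
    import sgl_kernel  # noqa: F401 -- importing the package registers the sgl_kernel::* ops (assumed)

    # Should print the schema above, including the trailing "Tensor workspace" argument.
    overload = torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default
    print(overload._schema)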
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu

@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");

@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else if (output.dtype() == torch::kFloat16) {

@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
         lm_problem_sizes,
         mm_problem_sizes,
         hm_problem_sizes,
+        workspace,
         is_h20_device,
         stream);
   } else {
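The dispatcher keeps its existing shape checks and now threads the workspace through to the launcher for every output-dtype branch. A caller-side sanity check in the same spirit is sketched below; the workspace and output-dtype assertions are assumptions about reasonable inputs, not checks the kernel is shown to perform.

    import torch


    def check_es_grouped_mm_inputs(problem_sizes: torch.Tensor,
                                   workspace: torch.Tensor,
                                   output: torch.Tensor) -> None:
        # Mirrors the TORCH_CHECKs visible in es_fp8_blockwise.cu.
        assert problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"
        assert problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"
        # Assumptions, not kernel-side checks: scratch is raw bytes on the GPU,
        # and the output dtype is one the dispatcher has a branch for.
        assert workspace.dtype == torch.uint8 and workspace.is_cuda
        assert output.dtype in (torch.bfloat16, torch.float16)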
sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh

@@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
+    const torch::Tensor& workspace,
     cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;

@@ -143,10 +144,6 @@
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");

-  torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device());
-  size_t workspace_size = gemm_op.get_workspace_size(args);
-  torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
-
   auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

@@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
     const torch::Tensor& hm_problem_sizes,
+    const torch::Tensor& workspace,
     bool is_h20_device,
     cudaStream_t stream) {
   using LowMGemmH20Traits =

@@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(

@@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         lm_problem_sizes,
+        workspace,
         stream);
   }

@@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfb,
         layout_sfa,
         mm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(

@@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         mm_problem_sizes,
+        workspace,
         stream);
   }

@@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(

@@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         layout_sfa,
         layout_sfb,
         hm_problem_sizes,
+        workspace,
         stream);
   }
 }
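With `gemm_op.get_workspace_size(...)` and the per-launch `torch::empty` gone from the launcher, the caller decides how the scratch buffer is managed. One possible policy, sketched under the assumption that a grow-only, per-device cache is acceptable (this commit itself just passes a preallocated buffer):

    import torch

    _WORKSPACE_CACHE = {}


    def get_workspace(device: torch.device, nbytes: int) -> torch.Tensor:
        # Keep one uint8 scratch buffer per device and enlarge it on demand,
        # so repeated grouped-GEMM launches reuse the same allocation.
        buf = _WORKSPACE_CACHE.get(device)
        if buf is None or buf.numel() < nbytes:
            buf = torch.empty((nbytes,), device=device, dtype=torch.uint8)
            _WORKSPACE_CACHE[device] = buf
        return buf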
sgl-kernel/include/sgl_kernel_ops.h

@@ -835,4 +835,5 @@ void es_fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& stride_b,
     const torch::Tensor& stride_d,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets);
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace);
sgl-kernel/python/sgl_kernel/expert_specialization.py

@@ -12,6 +12,7 @@ def es_fp8_blockwise_scaled_grouped_mm(
     stride_d,
     problem_sizes,
     expert_offsets,
+    workspace,
 ):
     torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
         output,

@@ -24,4 +25,5 @@ def es_fp8_blockwise_scaled_grouped_mm(
         stride_d,
         problem_sizes,
         expert_offsets,
+        workspace,
     )
sgl-kernel/tests/test_es_fp8_blockwise_moe.py

@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
     b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
     b_scale_stack = b_scale_stack.transpose(1, 2)
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full((num_experts,), a_stack.stride(0), device=device, dtype=torch.int64)

@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         d_strides,
         problem_sizes,
         expert_offsets[:-1],
+        workspace,
     )

     for g in range(num_experts):
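The per-expert verification loop that follows is unchanged by this commit. For completeness, a hedged sketch of what such a check typically looks like; it assumes the per-expert inputs have already been dequantized to float32, and the tolerances are placeholders rather than the test's actual values.

    import torch


    def check_expert_output(a_g: torch.Tensor, b_g: torch.Tensor, c_g: torch.Tensor,
                            rtol: float = 1e-2, atol: float = 1e-1) -> None:
        # a_g: (m_g, k) and b_g: (k, n) in float32; c_g: (m_g, n) slice of the kernel output.
        ref = (a_g @ b_g).to(c_g.dtype)
        torch.testing.assert_close(c_g, ref, rtol=rtol, atol=atol)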