Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
13c48dcf
Unverified
Commit
13c48dcf
authored
Aug 12, 2025
by
Trevor Morris
Committed by
GitHub
Aug 12, 2025
Browse files
[1/2][resubmit again] sgl-kernel: Fuse routed scaling factor into moe_fused_gate (#9088)
parent
8723b4f1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
11 deletions
+32
-11
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+1
-1
sgl-kernel/csrc/moe/moe_fused_gate.cu
sgl-kernel/csrc/moe/moe_fused_gate.cu
+20
-7
sgl-kernel/include/sgl_kernel_ops.h
sgl-kernel/include/sgl_kernel_ops.h
+2
-1
sgl-kernel/python/sgl_kernel/moe.py
sgl-kernel/python/sgl_kernel/moe.py
+9
-2
No files found.
sgl-kernel/csrc/common_extension.cc
View file @
13c48dcf
...
@@ -175,7 +175,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -175,7 +175,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m
.
def
(
m
.
def
(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor) -> "
"num_fused_shared_experts, float routed_scaling_factor
, bool apply_routed_scaling_factor_on_output
) -> "
"(Tensor[])"
);
"(Tensor[])"
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
m
.
impl
(
"moe_fused_gate"
,
torch
::
kCUDA
,
&
moe_fused_gate
);
m
.
def
(
m
.
def
(
...
...
sgl-kernel/csrc/moe/moe_fused_gate.cu
View file @
13c48dcf
...
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
...
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
int64_t
topk
,
int64_t
topk
,
int64_t
num_fused_shared_experts
,
int64_t
num_fused_shared_experts
,
double
routed_scaling_factor
,
double
routed_scaling_factor
,
bool
apply_routed_scaling_factor_on_output
,
Params
params
)
{
Params
params
)
{
int
tidx
=
threadIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
int64_t
thread_row
=
int64_t
thread_row
=
...
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
...
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
for
(
int
ii
=
0
;
ii
<
topk
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
topk
;
++
ii
)
{
int64_t
const
idx
=
topk
*
thread_row
+
ii
;
int64_t
const
idx
=
topk
*
thread_row
+
ii
;
output_ptr
[
idx
]
=
output_ptr
[
idx
]
/
output_sum
;
output_ptr
[
idx
]
=
output_ptr
[
idx
]
/
output_sum
;
if
(
apply_routed_scaling_factor_on_output
)
{
output_ptr
[
idx
]
*=
routed_scaling_factor
;
}
}
}
}
}
}
}
...
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
...
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
int64_t
topk_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
topk
,
int64_t
num_fused_shared_experts
,
int64_t
num_fused_shared_experts
,
double
routed_scaling_factor
)
{
double
routed_scaling_factor
,
bool
apply_routed_scaling_factor_on_output
)
{
KernelParams
<
VPT
,
NUM_EXPERTS
,
THREADS_PER_ROW
,
ROWS_PER_WARP
,
ROWS_PER_CTA
,
WARPS_PER_CTA
>
params
;
KernelParams
<
VPT
,
NUM_EXPERTS
,
THREADS_PER_ROW
,
ROWS_PER_WARP
,
ROWS_PER_CTA
,
WARPS_PER_CTA
>
params
;
moe_fused_gate_impl
<
T
>
(
moe_fused_gate_impl
<
T
>
(
input
,
input
,
...
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
...
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
,
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
,
params
);
params
);
}
}
...
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
...
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
topk_group, \
topk_group, \
topk, \
topk, \
num_fused_shared_experts, \
num_fused_shared_experts, \
routed_scaling_factor); \
routed_scaling_factor, \
apply_routed_scaling_factor_on_output); \
dispatched = true; \
dispatched = true; \
} while (0)
} while (0)
...
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
...
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
int64_t
topk_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
topk
,
int64_t
num_fused_shared_experts
,
int64_t
num_fused_shared_experts
,
double
routed_scaling_factor
)
{
double
routed_scaling_factor
,
bool
apply_routed_scaling_factor_on_output
)
{
KernelParamsDynamic
params
;
KernelParamsDynamic
params
;
params
.
NUM_EXPERTS
=
num_experts
;
// e.g, for deepseek v3, this is 256
params
.
NUM_EXPERTS
=
num_experts
;
// e.g, for deepseek v3, this is 256
params
.
VPT
=
num_experts
/
num_expert_group
;
// e.g., for deepseek v3, this is 256 / 8 = 32
params
.
VPT
=
num_experts
/
num_expert_group
;
// e.g., for deepseek v3, this is 256 / 8 = 32
...
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
...
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
,
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
,
params
);
params
);
}
}
...
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
...
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t
topk_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
topk
,
int64_t
num_fused_shared_experts
,
int64_t
num_fused_shared_experts
,
double
routed_scaling_factor
)
{
double
routed_scaling_factor
,
bool
apply_routed_scaling_factor_on_output
)
{
int64_t
num_rows
=
input
.
size
(
0
);
int64_t
num_rows
=
input
.
size
(
0
);
int32_t
num_experts
=
input
.
size
(
1
);
int32_t
num_experts
=
input
.
size
(
1
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
);
...
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
...
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group
,
topk_group
,
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
);
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
}
else
if
(
input
.
scalar_type
()
==
at
::
kHalf
)
{
moe_fused_gate_kernel_dynamic
<
float16_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
moe_fused_gate_kernel_dynamic
<
float16_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
input
.
data_ptr
(),
...
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
...
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group
,
topk_group
,
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
);
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
);
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
}
else
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
{
moe_fused_gate_kernel_dynamic
<
float32_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
moe_fused_gate_kernel_dynamic
<
float32_t
><<<
num_blocks
,
block_dim
,
0
,
stream
>>>
(
input
.
data_ptr
(),
input
.
data_ptr
(),
...
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
...
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group
,
topk_group
,
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
);
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type for moe_fused_gate"
);
TORCH_CHECK
(
false
,
"Unsupported data type for moe_fused_gate"
);
}
}
...
...
sgl-kernel/include/sgl_kernel_ops.h
View file @
13c48dcf
...
@@ -247,7 +247,8 @@ std::vector<at::Tensor> moe_fused_gate(
...
@@ -247,7 +247,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t
topk_group
,
int64_t
topk_group
,
int64_t
topk
,
int64_t
topk
,
int64_t
num_fused_shared_experts
,
int64_t
num_fused_shared_experts
,
double
routed_scaling_factor
);
double
routed_scaling_factor
,
bool
apply_routed_scaling_factor_on_output
);
void
fp8_blockwise_scaled_grouped_mm
(
void
fp8_blockwise_scaled_grouped_mm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output
,
...
...
sgl-kernel/python/sgl_kernel/moe.py
View file @
13c48dcf
...
@@ -44,6 +44,7 @@ def moe_fused_gate(
...
@@ -44,6 +44,7 @@ def moe_fused_gate(
topk
,
topk
,
num_fused_shared_experts
=
0
,
num_fused_shared_experts
=
0
,
routed_scaling_factor
=
0
,
routed_scaling_factor
=
0
,
apply_routed_scaling_factor_on_output
=
False
,
):
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
...
@@ -51,8 +52,13 @@ def moe_fused_gate(
...
@@ -51,8 +52,13 @@ def moe_fused_gate(
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
# num_fused_shared_experts: if > 0, the last several experts will be
# routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
# replaced with shared experts. the shared experts will be divided by the
# routed_scaling_factor - this is intended to cancel out later when routed+shared
# output is scaled so that shared experts are not scaled.
# routed_scaling_factor: if > 0, the experts will be scaled by this factor
# apply_routed_scaling_factor_on_output: if true, output will be
# scaled by the routed_scaling_factor
return
torch
.
ops
.
sgl_kernel
.
moe_fused_gate
.
default
(
return
torch
.
ops
.
sgl_kernel
.
moe_fused_gate
.
default
(
input_tensor
,
input_tensor
,
bias
,
bias
,
...
@@ -61,6 +67,7 @@ def moe_fused_gate(
...
@@ -61,6 +67,7 @@ def moe_fused_gate(
topk
,
topk
,
num_fused_shared_experts
,
num_fused_shared_experts
,
routed_scaling_factor
,
routed_scaling_factor
,
apply_routed_scaling_factor_on_output
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment