Commit dd949ace (unverified)
Revert "[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into m… (#9035)
Authored Aug 10, 2025 by Yineng Zhang; committed via GitHub on Aug 10, 2025
Parent commit: f2887498
Showing 6 changed files with 12 additions and 62 deletions:

  python/sglang/srt/layers/moe/topk.py        +0   -24
  sgl-kernel/csrc/common_extension.cc         +1   -1
  sgl-kernel/csrc/moe/moe_fused_gate.cu       +7   -20
  sgl-kernel/include/sgl_kernel_ops.h         +1   -2
  sgl-kernel/python/sgl_kernel/moe.py         +2   -9
  sgl-kernel/tests/test_moe_fused_gate.py     +1   -6
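In short, this revert removes the apply_routed_scaling_factor_on_output plumbing from the Python top-k routing path and from the moe_fused_gate kernel interface, so the routed scaling factor is once again applied outside the fused gate. Below is a minimal sketch of the two placements; the names, shapes, and the plain softmax-and-topk router are illustrative only and are not sglang code.

import torch

def topk_unfused(gating_output, topk, routed_scaling_factor):
    # Post-revert behaviour: routing only selects and renormalizes top-k weights;
    # the caller scales the combined MoE output by routed_scaling_factor later.
    weights, ids = torch.topk(torch.softmax(gating_output, dim=-1), topk, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids

def topk_fused(gating_output, topk, routed_scaling_factor, apply_on_output=True):
    # Pre-revert behaviour (the change being undone): the factor is folded into
    # the top-k weights inside the routing step itself.
    weights, ids = torch.topk(torch.softmax(gating_output, dim=-1), topk, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    if apply_on_output:
        weights = weights * routed_scaling_factor
    return weights, ids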
python/sglang/srt/layers/moe/topk.py (view file @ dd949ace)

@@ -132,7 +132,6 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
-        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -148,9 +147,6 @@ class TopK(CustomOp):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
-        self.apply_routed_scaling_factor_on_output = (
-            apply_routed_scaling_factor_on_output
-        )
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
@@ -211,7 +207,6 @@ class TopK(CustomOp):
             routed_scaling_factor=self.routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
-            apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
         )

     def forward_cpu(
@@ -381,7 +376,6 @@ def grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -429,8 +423,6 @@ def grouped_topk_gpu(
         else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
-    if apply_routed_scaling_factor_on_output:
-        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -479,7 +471,6 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -531,8 +522,6 @@ def biased_grouped_topk_impl(
         else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
     )
     topk_weights = topk_weights / topk_weights_sum
-    if apply_routed_scaling_factor_on_output:
-        topk_weights *= routed_scaling_factor

     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -575,10 +564,7 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
-    # TODO(trevor-m): Remove once sgl-kernel is updated
-    assert not apply_routed_scaling_factor_on_output
     assert (
         routed_scaling_factor is not None
     ), "routed_scaling_factor is required for biased_grouped_topk"
@@ -597,8 +583,6 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
-            # TODO(trevor-m): Uncomment once sgl-kernel is updated
-            # apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -609,7 +593,6 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
@@ -641,7 +624,6 @@ def biased_grouped_topk_gpu(
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
-           apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )
@@ -701,7 +683,6 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -727,7 +708,6 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
-                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -742,14 +722,12 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
-                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -757,7 +735,6 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -772,7 +749,6 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
-        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
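The removed Python fallback above applied the factor right after renormalizing the top-k weights. Here is a standalone sketch of that step; the guard on num_fused_shared_experts for choosing the `[:, :-1]` sum is an assumption (the condition is outside the visible hunk), and the function name is illustrative.

import torch

def renormalize_topk_weights(topk_weights, num_fused_shared_experts=0,
                             routed_scaling_factor=None,
                             apply_routed_scaling_factor_on_output=False):
    # Exclude the fused shared-expert column (assumed to be the last slot)
    # from the normalization sum when present.
    if num_fused_shared_experts > 0:
        topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    else:
        topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights / topk_weights_sum
    # These two lines mirror the ones this commit removes from grouped_topk_gpu /
    # biased_grouped_topk_impl; after the revert the caller scales the MoE output instead.
    if apply_routed_scaling_factor_on_output:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights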
sgl-kernel/csrc/common_extension.cc (view file @ dd949ace)

@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
+      "num_fused_shared_experts, float routed_scaling_factor) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);

   m.def(
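With the trailing bool dropped from the schema, the op takes seven arguments again. A hedged sketch of how those arguments line up when calling the registered op from Python (placeholder values; the wrapper in sgl_kernel/moe.py shown later in this diff does essentially this):

import torch
import sgl_kernel  # assumed to register the sgl_kernel op library on import

def call_moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk,
                        num_fused_shared_experts, routed_scaling_factor):
    # Argument order mirrors the reverted schema string:
    # "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group,
    #  int topk, int num_fused_shared_experts, float routed_scaling_factor) -> (Tensor[])"
    return torch.ops.sgl_kernel.moe_fused_gate.default(
        input_tensor, bias, num_expert_group, topk_group, topk,
        num_fused_shared_experts, routed_scaling_factor,
    )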
sgl-kernel/csrc/moe/moe_fused_gate.cu (view file @ dd949ace)

@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
-      if (apply_routed_scaling_factor_on_output) {
-        output_ptr[idx] *= routed_scaling_factor;
-      }
     }
   }
 }
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
         topk_group,                                \
         topk,                                      \
         num_fused_shared_experts,                  \
-        routed_scaling_factor,                     \
-        apply_routed_scaling_factor_on_output);    \
+        routed_scaling_factor);                    \
     dispatched = true;                             \
   } while (0)
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;               // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;    // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
-      apply_routed_scaling_factor_on_output,
       params);
 }
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output) {
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor,
-        apply_routed_scaling_factor_on_output);
+        routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
        topk_group,
        topk,
        num_fused_shared_experts,
-       routed_scaling_factor,
-       apply_routed_scaling_factor_on_output);
+       routed_scaling_factor);
  } else if (input.scalar_type() == at::kFloat) {
    moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
        input.data_ptr(),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
        topk_group,
        topk,
        num_fused_shared_experts,
-       routed_scaling_factor,
-       apply_routed_scaling_factor_on_output);
+       routed_scaling_factor);
  } else {
    TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
  }
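For reference, the normalization loop touched above divides each of a row's `topk` outputs by `output_sum`; the removed branch multiplied them by the scaling factor in the same pass. A line-for-line Python rendering of that loop (illustrative only, not kernel code):

def normalize_row(output, thread_row, topk, output_sum,
                  routed_scaling_factor, apply_routed_scaling_factor_on_output):
    # Each row owns `topk` consecutive entries of the flat output buffer.
    for ii in range(topk):
        idx = topk * thread_row + ii
        output[idx] = output[idx] / output_sum
        if apply_routed_scaling_factor_on_output:  # branch removed by this revert
            output[idx] *= routed_scaling_factor
    return output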
sgl-kernel/include/sgl_kernel_ops.h (view file @ dd949ace)

@@ -243,8 +243,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor,
-    bool apply_routed_scaling_factor_on_output);
+    double routed_scaling_factor);

 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
sgl-kernel/python/sgl_kernel/moe.py (view file @ dd949ace)

@@ -44,7 +44,6 @@ def moe_fused_gate(
     topk,
     num_fused_shared_experts=0,
     routed_scaling_factor=0,
-    apply_routed_scaling_factor_on_output=False,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
@@ -52,13 +51,8 @@ def moe_fused_gate(
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # num_fused_shared_experts: if > 0, the last several experts will be
-    # replaced with shared experts. the shared experts will be divided by the
-    # routed_scaling_factor - this is intended to cancel out later when routed+shared
-    # output is scaled so that shared experts are not scaled.
-    # routed_scaling_factor: if > 0, the experts will be scaled by this factor
-    # apply_routed_scaling_factor_on_output: if true, output will be
-    # scaled by the routed_scaling_factor
+    # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
+    # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
         bias,
@@ -67,7 +61,6 @@ def moe_fused_gate(
         topk,
         num_fused_shared_experts,
         routed_scaling_factor,
-        apply_routed_scaling_factor_on_output,
     )
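A hedged usage sketch of the wrapper after the revert. The 256-expert / 8-group sizing follows the DeepSeek-V3 example in the kernel comments and routed_scaling_factor=2.5 follows the test below; the import path, the topk_group value, and the two-tensor unpacking of the returned list are assumptions, not taken from this diff.

import torch
from sgl_kernel import moe_fused_gate  # assumed export of the wrapper above

scores = torch.rand(16, 256, dtype=torch.float32, device="cuda")  # [num_tokens, num_experts]
bias = torch.rand(256, dtype=torch.float32, device="cuda")        # per-expert correction bias

topk_weights, topk_ids = moe_fused_gate(
    scores,
    bias,
    8,  # num_expert_group (DeepSeek-V3-like)
    4,  # topk_group (assumed value)
    topk=8,
    num_fused_shared_experts=0,
    routed_scaling_factor=2.5,
)
# Any scaling of the routed output by routed_scaling_factor now happens in the caller.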
sgl-kernel/tests/test_moe_fused_gate.py (view file @ dd949ace)

@@ -19,10 +19,7 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
-def test_moe_fused_gate_combined(
-    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
-):
+def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32
@@ -40,7 +37,6 @@ def test_moe_fused_gate_combined(
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -52,7 +48,6 @@ def test_moe_fused_gate_combined(
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
-        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
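The final comment hints at how the test compares kernel output against the biased_grouped_topk reference when shared experts are fused: the last top-k column is dropped before comparing. A sketch of that comparison under stated assumptions (the helper name, the sorting, and the tolerances are illustrative, not the test's actual code):

import torch

def compare_outputs(output, ref_output, num_fused_shared_experts):
    # Drop the fused shared-expert slot (assumed to be the last column) before comparing.
    if num_fused_shared_experts > 0:
        output, ref_output = output[:, :-1], ref_output[:, :-1]
    # Sort per row so the comparison is order-insensitive (assumption).
    torch.testing.assert_close(
        output.sort(dim=-1).values,
        ref_output.sort(dim=-1).values,
        rtol=1e-2,
        atol=1e-3,
    )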