"vscode:/vscode.git/clone" did not exist on "d2ef679d42feac9901c347bc24ff0b6ab67a938b"
Unverified Commit 591c232f authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into moe_fused_gate (select_experts) (#8770)

parent f352b793
......@@ -132,6 +132,7 @@ class TopK(CustomOp):
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
......@@ -147,6 +148,9 @@ class TopK(CustomOp):
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.apply_routed_scaling_factor_on_output = (
apply_routed_scaling_factor_on_output
)
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
......@@ -207,6 +211,7 @@ class TopK(CustomOp):
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
)
def forward_cpu(
......@@ -375,6 +380,7 @@ def grouped_topk_gpu(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -422,6 +428,8 @@ def grouped_topk_gpu(
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
if apply_routed_scaling_factor_on_output:
topk_weights *= routed_scaling_factor
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
......@@ -468,6 +476,7 @@ def biased_grouped_topk_impl(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
......@@ -519,6 +528,8 @@ def biased_grouped_topk_impl(
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
if apply_routed_scaling_factor_on_output:
topk_weights *= routed_scaling_factor
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
......@@ -561,7 +572,10 @@ def biased_grouped_topk_gpu(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
# TODO(trevor-m): Remove once sgl-kernel is updated
assert not apply_routed_scaling_factor_on_output
assert (
routed_scaling_factor is not None
), "routed_scaling_factor is required for biased_grouped_topk"
......@@ -580,6 +594,8 @@ def biased_grouped_topk_gpu(
topk,
num_fused_shared_experts,
routed_scaling_factor,
# TODO(trevor-m): Uncomment once sgl-kernel is updated
# apply_routed_scaling_factor_on_output,
)
# TODO merge into kernel
if (expert_location_dispatch_info is not None) or (
......@@ -590,6 +606,7 @@ def biased_grouped_topk_gpu(
)
return topk_weights, topk_ids
elif _use_aiter:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
token = gating_output.shape[0]
device = gating_output.device
assert (
......@@ -621,6 +638,7 @@ def biased_grouped_topk_gpu(
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
......@@ -680,6 +698,7 @@ def select_experts(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
) -> TopKOutput:
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
......@@ -705,6 +724,7 @@ def select_experts(
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
......@@ -719,12 +739,14 @@ def select_experts(
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
elif torch_native and custom_routing_function is None:
assert (
num_token_non_padded is None
), "num_token_non_padded is not yet supported in fused_topk_native"
assert expert_location_dispatch_info is None
assert not apply_routed_scaling_factor_on_output, "Not implemented"
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
......@@ -732,6 +754,7 @@ def select_experts(
renormalize=renormalize,
)
elif custom_routing_function is None:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
# Qwen3MOE uses fused_topk
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
......@@ -746,6 +769,7 @@ def select_experts(
num_token_non_padded is None
), "num_token_non_padded is not yet supported in custom_routing_function"
assert expert_location_dispatch_info is None
assert not apply_routed_scaling_factor_on_output, "Not implemented"
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
......
......@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor) -> "
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
......
......@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
......@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = output_ptr[idx] / output_sum;
if (apply_routed_scaling_factor_on_output) {
output_ptr[idx] *= routed_scaling_factor;
}
}
}
}
......@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(
input,
......@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}
......@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
topk_group, \
topk, \
num_fused_shared_experts, \
routed_scaling_factor); \
routed_scaling_factor, \
apply_routed_scaling_factor_on_output); \
dispatched = true; \
} while (0)
......@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
......@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
params);
}
......@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor) {
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
......@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
......@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
......@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor);
routed_scaling_factor,
apply_routed_scaling_factor_on_output);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
......
......@@ -243,7 +243,8 @@ std::vector<at::Tensor> moe_fused_gate(
int64_t topk_group,
int64_t topk,
int64_t num_fused_shared_experts,
double routed_scaling_factor);
double routed_scaling_factor,
bool apply_routed_scaling_factor_on_output);
void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
......
......@@ -44,6 +44,7 @@ def moe_fused_gate(
topk,
num_fused_shared_experts=0,
routed_scaling_factor=0,
apply_routed_scaling_factor_on_output=False,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
......@@ -51,8 +52,13 @@ def moe_fused_gate(
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
# routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
# num_fused_shared_experts: if > 0, the last several experts will be
# replaced with shared experts. the shared experts will be divided by the
# routed_scaling_factor - this is intended to cancel out later when routed+shared
# output is scaled so that shared experts are not scaled.
# routed_scaling_factor: if > 0, the experts will be scaled by this factor
# apply_routed_scaling_factor_on_output: if true, output will be
# scaled by the routed_scaling_factor
return torch.ops.sgl_kernel.moe_fused_gate.default(
input_tensor,
bias,
......@@ -61,6 +67,7 @@ def moe_fused_gate(
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
)
......
......@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
def test_moe_fused_gate_combined(
seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
num_experts, num_expert_group, topk_group, topk = params
dtype = torch.float32
......@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk=topk,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
ref_output, ref_indices = biased_grouped_topk(
scores,
......@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
topk_group=topk_group,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=2.5,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment