Unverified Commit 13c48dcf authored by Trevor Morris, committed by GitHub

[1/2][resubmit again] sgl-kernel: Fuse routed scaling factor into moe_fused_gate (#9088)

parent 8723b4f1
@@ -175,7 +175,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-      "num_fused_shared_experts, float routed_scaling_factor) -> "
+      "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(
...
@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
     int64_t topk,
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
       output_ptr[idx] = output_ptr[idx] / output_sum;
+      if (apply_routed_scaling_factor_on_output) {
+        output_ptr[idx] *= routed_scaling_factor;
+      }
     }
   }
 }
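
The second hunk above is the functional core of the change: once each row's top-k weights have been normalized by output_sum, the routed scaling factor can now be multiplied in during the same pass instead of in a separate elementwise step afterwards. A minimal PyTorch sketch of the equivalent post-processing, with illustrative names only (this is not the kernel's code):

import torch

def normalize_and_scale(topk_weights: torch.Tensor,
                        routed_scaling_factor: float,
                        apply_routed_scaling_factor_on_output: bool) -> torch.Tensor:
    # Mirror of output_ptr[idx] / output_sum: each row of top-k weights
    # is normalized so that it sums to 1.
    weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # New in this commit: when the flag is set, the routed scaling factor is
    # folded into the same pass instead of being applied later by the caller.
    if apply_routed_scaling_factor_on_output:
        weights = weights * routed_scaling_factor
    return weights
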
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
       input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
       topk_group,                             \
       topk,                                   \
       num_fused_shared_experts,               \
-      routed_scaling_factor);                 \
+      routed_scaling_factor,                  \
+      apply_routed_scaling_factor_on_output); \
   dispatched = true;                          \
   } while (0)
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;             // e.g., for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
       topk,
       num_fused_shared_experts,
       routed_scaling_factor,
+      apply_routed_scaling_factor_on_output,
       params);
 }
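
The dynamic-path hunks only thread the new flag through to moe_fused_gate_impl; the parameter sizing is unchanged and follows the comments above. For the DeepSeek-V3 layout those comments cite, the arithmetic works out as in this small sketch (the 256 and 8 figures come from the comments, everything else is illustrative):

# Figures quoted in the kernel comments for DeepSeek-V3 (assumed here).
num_experts = 256
num_expert_group = 8

NUM_EXPERTS = num_experts                  # 256 experts in total
VPT = num_experts // num_expert_group      # values per thread: 256 // 8 = 32
print(NUM_EXPERTS, VPT)                    # 256 32
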
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor) {
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
         input.data_ptr(),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
         topk_group,
         topk,
         num_fused_shared_experts,
-        routed_scaling_factor);
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
   }
...
@@ -247,7 +247,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t topk_group,
     int64_t topk,
     int64_t num_fused_shared_experts,
-    double routed_scaling_factor);
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output);
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
...
@@ -44,6 +44,7 @@ def moe_fused_gate(
     topk,
     num_fused_shared_experts=0,
     routed_scaling_factor=0,
+    apply_routed_scaling_factor_on_output=False,
 ):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
@@ -51,8 +52,13 @@ def moe_fused_gate(
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
     # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-    # num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
-    # routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
+    # num_fused_shared_experts: if > 0, the last several experts will be
+    # replaced with shared experts. the shared experts will be divided by the
+    # routed_scaling_factor - this is intended to cancel out later when routed+shared
+    # output is scaled so that shared experts are not scaled.
+    # routed_scaling_factor: if > 0, the experts will be scaled by this factor
+    # apply_routed_scaling_factor_on_output: if true, output will be
+    # scaled by the routed_scaling_factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
         input_tensor,
         bias,
@@ -61,6 +67,7 @@ def moe_fused_gate(
         topk,
         num_fused_shared_experts,
         routed_scaling_factor,
+        apply_routed_scaling_factor_on_output,
     )
...
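
On the Python side, the net effect is that callers can ask the gate kernel to apply the routed scaling factor itself rather than multiplying the returned weights afterwards. A hedged usage sketch follows; the import path, the two-tensor return, and the fused/unfused equivalence check are assumptions for illustration rather than part of this commit:

import torch
from sgl_kernel import moe_fused_gate  # import path assumed

# Shapes loosely modeled on a DeepSeek-V3-style router (illustrative only).
num_tokens, num_experts = 4, 256
scores = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(num_experts, dtype=torch.bfloat16, device="cuda")
routed_scaling_factor = 2.5

# Previous behavior: the kernel returns normalized weights and the caller scales them.
w_unfused, ids_unfused = moe_fused_gate(
    scores, bias, num_expert_group=8, topk_group=4, topk=8,
    num_fused_shared_experts=0, routed_scaling_factor=routed_scaling_factor,
    apply_routed_scaling_factor_on_output=False,
)
w_unfused = w_unfused * routed_scaling_factor

# New behavior: the scaling is fused into the kernel's output pass.
w_fused, ids_fused = moe_fused_gate(
    scores, bias, num_expert_group=8, topk_group=4, topk=8,
    num_fused_shared_experts=0, routed_scaling_factor=routed_scaling_factor,
    apply_routed_scaling_factor_on_output=True,
)

# The two paths should agree up to floating-point rounding.
torch.testing.assert_close(w_fused, w_unfused)

When num_fused_shared_experts > 0, the updated docstring notes that the shared-expert weight is pre-divided by routed_scaling_factor, so applying the factor on the combined output leaves the shared experts effectively unscaled.
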