Unverified commit 8e09b370 authored by Xiaoyu Zhang, committed by GitHub

Sgl kernel fused_moe_gate support n_shared_experts (#5440)

parent 53dcf388
@@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

   m.def(
-      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
+      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
+      "n_share_experts_fusion, float routed_scaling_factor) -> "
      "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
...
@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor,
     Params params) {
   int tidx = threadIdx.x;
   int64_t thread_row =
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
     return;
   }

+  // Calculate topk_excluding_share_expert_fusion from topk
+  int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
+
   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
   auto* bias_ptr = reinterpret_cast<T*>(bias);
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
   ////////////////////// Topk //////////////////////
   float output_sum = 0.0f;
-  for (int k_idx = 0; k_idx < topk; ++k_idx) {
+  for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
     // local argmax
     T max_val = bias_chunk[0];
     int expert = first_elt_read_by_thread;
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
       max_val = static_cast<T>(-FLT_MAX);
     }

     // argmax reduce
#pragma unroll
     for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
       T other_max =
@@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl(
       }
     }

-    if (k_idx < topk) {
     int thread_to_clear_in_group = expert / params.VPT;
     int64_t idx = topk * thread_row + k_idx;

     if (thread_group_idx == thread_to_clear_in_group) {
       int expert_to_clear_in_thread = expert % params.VPT;
       // clear the max value in the thread
       bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
       // store output
       output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
       indices_ptr[idx] = static_cast<int32_t>(expert);
     }

-    // accumulate sum
+    // accumulate sum for all elements
     if (thread_group_idx == 0) {
       output_sum += output_ptr[idx];
     }
-    }

     __syncthreads();
   }

+  if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
+    int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
+    // Use round-robin to select expert
+    int64_t expert_offset = thread_row % n_share_experts_fusion;
+    indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
+    // Set the weight to the sum of all weights divided by routed_scaling_factor
+    output_ptr[last_idx] = output_sum / routed_scaling_factor;
+  }
+  __syncthreads();
+
   ////////////////////// Rescale Output //////////////////////
   if (thread_group_idx == 0) {
#pragma unroll
     for (int ii = 0; ii < topk; ++ii) {
       int64_t const idx = topk * thread_row + ii;
-      output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
+      output_ptr[idx] = output_ptr[idx] / output_sum;
     }
   }
 }
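For readers tracing the kernel change above, the routed top-k plus fused shared-expert selection can be sketched in plain PyTorch. This is a minimal reference sketch, not the fused kernel: the hierarchical group filtering and the sigmoid/bias score preparation of the real implementation are omitted, and the function name `gate_with_shared_fusion` and its arguments are illustrative only.

import torch


def gate_with_shared_fusion(scores, score_bias, topk, n_share_experts_fusion, routed_scaling_factor):
    # Reference sketch of the per-row selection (group filtering omitted).
    num_tokens, num_experts = scores.shape
    # One top-k slot is reserved for the fused shared expert when fusion is enabled.
    topk_routed = topk - (1 if n_share_experts_fusion > 0 else 0)

    # Select routed experts by the biased score, but keep the unbiased score as the weight.
    _, idx = torch.topk(scores + score_bias, topk_routed, dim=-1)
    weights = torch.gather(scores, 1, idx).float()

    if n_share_experts_fusion > 0:
        # Round-robin shared expert per row: index = num_experts + (row % n_share_experts_fusion).
        rows = torch.arange(num_tokens, device=scores.device)
        shared_idx = (num_experts + rows % n_share_experts_fusion).unsqueeze(1)
        # Its weight is the sum of the routed weights divided by routed_scaling_factor.
        shared_w = weights.sum(dim=-1, keepdim=True) / routed_scaling_factor
        idx = torch.cat([idx, shared_idx], dim=-1)
        weights = torch.cat([weights, shared_w], dim=-1)

    # Rescale every slot by the sum of the routed weights only, matching the kernel's output_sum.
    routed_sum = weights[:, :topk_routed].sum(dim=-1, keepdim=True)
    return weights / routed_sum, idx.to(torch.int32)

After the final rescale the routed weights sum to 1 and the fused shared-expert slot carries a constant weight of 1 / routed_scaling_factor, which is what the `output_sum / routed_scaling_factor` line in the kernel produces.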
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
     int32_t* indices_ptr,
     int64_t num_rows,
     int64_t topk_group,
-    int64_t topk) {
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
-  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+  moe_fused_gate_impl<T>(
+      input,
+      bias,
+      output_ptr,
+      indices_ptr,
+      num_rows,
+      topk_group,
+      topk,
+      n_share_experts_fusion,
+      routed_scaling_factor,
+      params);
 }
 // Macro to compute compile-time constants and launch the kernel.
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
       indices.data_ptr<int32_t>(), \
       num_rows,                    \
       topk_group,                  \
-      topk);                       \
+      topk,                        \
+      n_share_experts_fusion,      \
+      routed_scaling_factor);      \
   dispatched = true;               \
 } while (0)
@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t num_experts,
     int64_t num_expert_group,
     int64_t topk_group,
-    int64_t topk) {
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;             // e.g, for deepseek v3, this is 256
   params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
   params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group);  // WARP_SIZE is fixed as 32
   params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;

-  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+  moe_fused_gate_impl<T>(
+      input,
+      bias,
+      output_ptr,
+      indices_ptr,
+      num_rows,
+      topk_group,
+      topk,
+      n_share_experts_fusion,
+      routed_scaling_factor,
+      params);
 }
 //------------------------------------------------------------------------------
 // Host Launcher Function
 //------------------------------------------------------------------------------
-std::vector<at::Tensor>
-moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
+std::vector<at::Tensor> moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
           num_experts,
           num_expert_group,
           topk_group,
-          topk);
+          topk,
+          n_share_experts_fusion,
+          routed_scaling_factor);
     } else if (input.scalar_type() == at::kHalf) {
       moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
           input.data_ptr(),
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
           num_experts,
           num_expert_group,
           topk_group,
-          topk);
+          topk,
+          n_share_experts_fusion,
+          routed_scaling_factor);
     } else if (input.scalar_type() == at::kFloat) {
       moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
           input.data_ptr(),
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
           num_experts,
           num_expert_group,
           topk_group,
-          topk);
+          topk,
+          n_share_experts_fusion,
+          routed_scaling_factor);
     } else {
       TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
     }
...
@@ -200,8 +200,14 @@ void topk_softmax(
     torch::Tensor& token_expert_indices,
     torch::Tensor& gating_output);

-std::vector<at::Tensor>
-moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
+std::vector<at::Tensor> moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk,
+    int64_t n_share_experts_fusion,
+    double routed_scaling_factor);

 /*
  * From csrc/speculative
...
@@ -34,13 +34,29 @@ def topk_softmax(
 )

-def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
+def moe_fused_gate(
+    input_tensor,
+    bias,
+    num_expert_group,
+    topk_group,
+    topk,
+    n_share_experts_fusion=0,
+    routed_scaling_factor=0,
+):
     # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
     # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
     # as the group weight to select exerpt groups and then select topk experts within the selected groups
     # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
     # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
     # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
+    # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
+    # routed_scaling_factor: if > 0, the last expert will be scaled by this factor
     return torch.ops.sgl_kernel.moe_fused_gate.default(
-        input_tensor, bias, num_expert_group, topk_group, topk
+        input_tensor,
+        bias,
+        num_expert_group,
+        topk_group,
+        topk,
+        n_share_experts_fusion,
+        routed_scaling_factor,
     )
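With the extended wrapper signature, a call that fuses one shared-expert slot could look like the sketch below. The import path and the concrete shapes (a DeepSeek-V3-style gate with 256 routed experts in 8 groups, as referenced in the kernel comments) are illustrative assumptions, not part of this change.

import torch
from sgl_kernel import moe_fused_gate  # assumed import path for the wrapper above

# Hypothetical DeepSeek-V3-style gate: 256 routed experts split into 8 groups.
logits = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(256, dtype=torch.bfloat16, device="cuda")

# topk counts the extra shared-expert slot: 8 routed experts + 1 fused shared expert = 9.
weights, indices = moe_fused_gate(
    logits,
    bias,
    num_expert_group=8,
    topk_group=4,
    topk=9,
    n_share_experts_fusion=1,
    routed_scaling_factor=2.5,
)
# indices[:, -1] is expected in [256, 256 + n_share_experts_fusion);
# its weight after rescaling is 1 / routed_scaling_factor.

Note that the extra slot is counted inside topk, which is why the test below bumps topk by min(1, n_share_experts_fusion).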
@@ -19,13 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
         (512, 16, 8, 16),
     ],
 )
-def test_moe_fused_gate_combined(seq_length, dtype, params):
+@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
+def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
     num_experts, num_expert_group, topk_group, topk = params
     torch.manual_seed(seq_length)

     tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
     scores = tensor.clone()
     bias = torch.rand(num_experts).to(dtype).cuda()
+    topk = topk + min(1, n_share_experts_fusion)

     output, indices = moe_fused_gate(
         tensor,
@@ -33,6 +35,8 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         topk=topk,
+        n_share_experts_fusion=n_share_experts_fusion,
+        routed_scaling_factor=2.5,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         compiled=False,
+        n_share_experts_fusion=n_share_experts_fusion,
     )

+    # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
+    if n_share_experts_fusion > 0:
+        original_indices = indices.clone()
+        original_ref_indices = ref_indices.clone()
+
+        indices = indices[:, :-1]
+        ref_indices = ref_indices[:, :-1]
+
+        valid_min = num_experts
+        valid_max = num_experts + n_share_experts_fusion
+
+        shared_indices = original_indices[:, -1]
+        shared_ref_indices = original_ref_indices[:, -1]
+        if shared_indices is not None:
+            assert torch.all(
+                (shared_indices >= valid_min) & (shared_indices < valid_max)
+            ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
+        if shared_ref_indices is not None:
+            assert torch.all(
+                (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
+            ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"
+
     idx_check = torch.allclose(
         ref_indices.sort()[0].to(torch.int32),
         indices.sort()[0].to(torch.int32),
@@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
     output_check = torch.allclose(
         ref_output.sort()[0].to(torch.float32),
         output.sort()[0].to(torch.float32),
-        rtol=1e-04,
-        atol=1e-05,
+        rtol=1e-02,
+        atol=1e-03,
     )

     assert idx_check, (
         f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}"
+        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
     )
     assert output_check, (
         f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}"
+        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
     )
...