Unverified Commit 45dcfc2e authored by Qingquan Song, committed by GitHub

Add deepseek style fused moe group gate selection kernel (#4530)

parent ddf8981d
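For context, a minimal usage sketch of the new op (a hedged example, not part of the diff, assuming a CUDA device and the deepseek v3 gating shape of 256 experts in 8 groups; the call matches the Python wrapper added in this PR):

import torch
from sgl_kernel import moe_fused_gate

scores = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)  # [num_tokens, num_experts] router logits
bias = torch.rand(256, device="cuda", dtype=torch.bfloat16)         # per-expert bias added before group selection
# Returns float32 weights and int32 expert indices, each of shape [num_tokens, topk].
weights, indices = moe_fused_gate(scores, bias, num_expert_group=8, topk_group=4, topk=8)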
@@ -151,6 +151,7 @@ set(SOURCES
    "csrc/gemm/per_token_group_quant_8bit.cu"
    "csrc/gemm/per_token_quant_fp8.cu"
    "csrc/moe/moe_align_kernel.cu"
    "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/speculative_sampling.cu"
...
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk(
scores,
scores,
bias,
topk=topk,
renormalize=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk):
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["seq_length"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["original", "kernel"],
line_names=["Original", "SGL Kernel"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="moe-fused-gate-performance",
args={},
)
)
def benchmark(seq_length, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
bias = torch.rand(num_experts, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if provider == "original":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: biased_grouped_topk_org(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
elif provider == "kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: biased_grouped_topk_org_kernel(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
benchmark.run(print_data=True)
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
template <typename T, int N>
using AlignedArray = cutlass::AlignedArray<T, N>;
using bfloat16_t = cutlass::bfloat16_t;
using float16_t = cutlass::half_t;
using float32_t = float;
// QQ NOTE: handles the at::Half case, which otherwise triggers: error: more than one operator ">" matches these
// operands: built-in operator "arithmetic > arithmetic" vs. function "operator>(const __half &, const __half &)"
template <typename T>
__device__ inline bool cmp_gt(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
return static_cast<float>(a) > static_cast<float>(b);
} else {
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
return a > b;
}
}
template <typename T>
__device__ inline bool cmp_eq(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
return static_cast<float>(a) == static_cast<float>(b);
} else {
return a == b;
}
}
// Fixed constants common to both dynamic and static template versions:
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_CTA = 6;
static constexpr int MAX_VPT = 32;  // maximum VPT we support; must be >= params.VPT = num_experts / num_expert_group
// Create an alias for Array using AlignedArray
template <typename T, int N>
using Array = AlignedArray<T, N>;
// QQ NOTE: the array length must be a compile-time constant, and it has to be >= params.VPT
template <typename T>
using AccessType = AlignedArray<T, MAX_VPT>;
template <typename T, typename Params>
__device__ void moe_fused_gate_impl(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
if (thread_row >= num_rows) {
return;
}
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
auto* bias_ptr = reinterpret_cast<T*>(bias);
auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;
int thread_group_idx = tidx % params.THREADS_PER_ROW;
int first_elt_read_by_thread = thread_group_idx * params.VPT;
// Create local arrays for the row chunk and bias chunk, and reinterpret the global read pointers as pointers to
// AccessType.
T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> row_chunk;
AccessType<T> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(thread_read_ptr);
T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> bias_chunk;
AccessType<T> const* vec_bias_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(bias_thread_read_ptr);
// QQ NOTE: doing the following would be slower than the loop assignment and, more importantly,
// causes a misaligned-address issue when params.VPT < 8 and does not match MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = vec_thread_read_ptr[0][ii];
bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
}
__syncthreads();
////////////////////// Sigmoid //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
}
__syncthreads();
////////////////////// Add Bias //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
}
////////////////////// Exclude Groups //////////////////////
#pragma unroll
for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
++k_idx) { // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
int expert = first_elt_read_by_thread;
// local argmax
T max_val = static_cast<T>(-FLT_MAX);
T max_val_second = static_cast<T>(-FLT_MAX);
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val_second = max_val;
max_val = val;
} else if (cmp_gt(val, max_val_second)) {
max_val_second = val;
}
}
// QQ NOTE: currently fixed to pick the top-2 biased sigmoid values in each expert group and sum them as the group
// weight used to select expert groups
T max_sum = max_val + max_val_second;
// argmin reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max_sum =
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
// higher indices win
if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
max_sum = other_max_sum;
expert = other_expert;
}
}
// mark the excluded group by setting all of its values to FLT_MAX
if (k_idx < params.THREADS_PER_ROW - topk_group) {
int const thread_to_clear_in_group = expert / params.VPT;
if (thread_group_idx == thread_to_clear_in_group) {
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = static_cast<T>(FLT_MAX);
}
}
}
}
__syncthreads();
////////////////////// Topk //////////////////////
float output_sum = 0.0f;
for (int k_idx = 0; k_idx < topk; ++k_idx) {
// local argmax
T max_val = bias_chunk[0];
int expert = first_elt_read_by_thread;
if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
#pragma unroll
for (int ii = 1; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val = val;
expert = first_elt_read_by_thread + ii;
}
}
} else {
max_val = static_cast<T>(-FLT_MAX);
}
// argmax reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max =
static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
// on ties, lower indices win
if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
if (k_idx < topk) {
int thread_to_clear_in_group = expert / params.VPT;
int64_t idx = topk * thread_row + k_idx;
if (thread_group_idx == thread_to_clear_in_group) {
int expert_to_clear_in_thread = expert % params.VPT;
// clear the max value in the thread
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
// store output
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
indices_ptr[idx] = static_cast<int32_t>(expert);
}
// accumulate sum
if (thread_group_idx == 0) {
output_sum += output_ptr[idx];
}
}
__syncthreads();
}
////////////////////// Rescale Output //////////////////////
if (thread_group_idx == 0) {
#pragma unroll
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
}
}
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
struct KernelParams {
static constexpr int VPT = VPT_;
static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
};
template <
typename T,
int VPT,
int NUM_EXPERTS,
int THREADS_PER_ROW,
int ROWS_PER_WARP,
int ROWS_PER_CTA,
int WARPS_PER_CTA>
__global__ void moe_fused_gate_kernel(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
}
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk); \
dispatched = true; \
} while (0)
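// For instance, LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8) (the deepseek v3 shape) instantiates the kernel with
// VPT = 32, THREADS_PER_ROW = 8, ROWS_PER_WARP = 4 and ROWS_PER_CTA = 24.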
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
struct KernelParamsDynamic {
int VPT;
int NUM_EXPERTS;
int THREADS_PER_ROW;
int ROWS_PER_WARP;
int ROWS_PER_CTA;
int WARPS_PER_CTA;
};
template <typename T>
__global__ void moe_fused_gate_kernel_dynamic(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t num_experts,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts;  // e.g., for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
params.THREADS_PER_ROW = num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, this is 8
params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto output = torch::empty({num_rows, topk}, options);
auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
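  // Worked example (deepseek v3 gating, num_expert_group = 8): rows_per_warp = 32 / 8 = 4, so one CTA of
  // WARPS_PER_CTA = 6 warps covers 24 rows; e.g. num_rows = 4096 gives num_warps = 1024 and num_blocks = 171.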
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is a power of 2)
TORCH_CHECK(
num_experts % num_expert_group == 0,
"num_experts must be divisible by num_expert_group, but got ",
num_experts,
" / ",
num_expert_group);
int computed_vpt = num_experts / num_expert_group;
// Check 3: Ensure that num_experts / num_expert_group does not exceed MAX_VPT = 32, the maximum number of values per
// thread we can process.
TORCH_CHECK(
computed_vpt <= MAX_VPT,
"Per group experts: num_experts / num_expert_group = (",
computed_vpt,
") exceeds the maximum supported (",
MAX_VPT,
")");
// Dispatch to the templated kernel for known compile-time configurations.
// We currently only support:
// Case 1: 256 experts, with 8 or 16 groups.
// Case 2: 128 experts, with 4 or 8 groups.
// Case 3: other cases, which require 8 <= num_experts / num_expert_group <= 32 and fall back to the dynamic kernel.
bool dispatched = false;
  switch (num_experts) {
    case 256:
      if (num_expert_group == 8) {
        // This is the deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
        }
      } else if (num_expert_group == 16) {
        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
        }
      }
      break;
    case 128:
      if (num_expert_group == 4) {
        // VPT = 128/4 = 32, ROWS_PER_WARP = 32/4 = 8, ROWS_PER_CTA = 6 * 8 = 48.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
        }
      } else if (num_expert_group == 8) {
        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
        if (input.scalar_type() == at::kBFloat16) {
          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
        } else if (input.scalar_type() == at::kHalf) {
          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
        } else if (input.scalar_type() == at::kFloat) {
          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
        }
      }
      break;
    default:
      break;
  }
if (!dispatched) {
    // Fall back to the dynamic kernel if none of the templated combinations match.
    // Currently only num_experts / num_expert_group <= 32 is supported for the dynamic kernel.
if (input.scalar_type() == at::kBFloat16) {
moe_fused_gate_kernel_dynamic<bfloat16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
}
return {output, indices};
}
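For readers cross-checking the kernel above, here is a rough pure-PyTorch restatement of the selection it implements, based on the comments in the code (a sketch, not part of the PR; tie-breaking order may differ from the CUDA implementation): sigmoid the logits, add the bias, score each group by the sum of its top-2 biased values, keep the best topk_group groups, take the top-k experts from those groups, and renormalize the unbiased sigmoid weights.

import torch

def moe_fused_gate_ref(scores, bias, num_expert_group, topk_group, topk):
    # scores: [num_rows, num_experts] router logits; bias: [num_experts] correction bias
    num_experts = scores.shape[-1]
    sig = torch.sigmoid(scores.float())
    biased = sig + bias.float()
    # Group weight = sum of the top-2 biased scores inside each expert group.
    grouped = biased.view(-1, num_expert_group, num_experts // num_expert_group)
    group_weight = grouped.topk(2, dim=-1).values.sum(dim=-1)
    # Keep the topk_group best groups and mask out the experts of all other groups.
    keep = group_weight.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_weight).scatter_(-1, keep, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(-1, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    # Final top-k over the surviving experts; weights are the unbiased sigmoid scores, renormalized.
    topk_idx = masked.topk(topk, dim=-1).indices
    topk_w = sig.gather(-1, topk_idx)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    return topk_w, topk_idx.to(torch.int32)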
@@ -138,6 +138,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
      "token_expert_indices, Tensor gating_output) -> ()");
  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
  /*
   * From csrc/speculative
   */
...
@@ -199,6 +199,9 @@ void topk_softmax(
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);
std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
/*
 * From csrc/speculative
 */
...
@@ -36,7 +36,7 @@ from sgl_kernel.gemm import (
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
)
from sgl_kernel.moe import moe_align_block_size, moe_fused_gate, topk_softmax
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
...
@@ -32,3 +32,15 @@ def topk_softmax(
    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, token_expert_indices, gating_output
    )
def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
    # This fused kernel selects the top-k experts in a hierarchical, two-level fashion:
    # it splits the experts into num_expert_group groups, uses the sum of the top-2 expert
    # weights in each group as the group weight to select topk_group expert groups, and then
    # selects the top-k experts within the selected groups.
    # The number of experts is inferred from the input tensor shape. We currently only support
    # a power-of-2 number of experts that is divisible by num_expert_group, with
    # num_experts / num_expert_group <= 32.
    # For unsupported cases, we suggest using the biased_grouped_topk function in
    # sglang.srt.layers.moe.topk instead.
return torch.ops.sgl_kernel.moe_fused_gate(
input_tensor, bias, num_expert_group, topk_group, topk
)
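Following the suggestion in the comment above, a hedged sketch of a dispatch helper that falls back to the pure-PyTorch path when the shape constraints are not met (the helper name select_experts_or_fallback is hypothetical and not part of this PR):

from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk

def select_experts_or_fallback(scores, bias, num_expert_group, topk_group, topk):
    num_experts = scores.shape[-1]
    supported = (
        (num_experts & (num_experts - 1)) == 0          # power-of-2 number of experts
        and num_experts % num_expert_group == 0
        and num_experts // num_expert_group <= 32        # MAX_VPT limit of the CUDA kernel
    )
    if supported:
        return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
    # Unsupported shape: use the reference grouped top-k from sglang instead.
    return biased_grouped_topk(
        scores, scores, bias,
        topk=topk, renormalize=True,
        num_expert_group=num_expert_group, topk_group=topk_group,
    )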
@@ -161,6 +161,7 @@ sources = [
    "csrc/gemm/per_token_quant_fp8.cu",
    "csrc/gemm/per_tensor_quant_fp8.cu",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_fused_gate.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/speculative/eagle_utils.cu",
    "csrc/speculative/speculative_sampling.cu",
...
import pytest
import torch
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
@pytest.mark.parametrize(
"seq_length",
list(range(1, 10))
+ [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
"params",
[
(128, 4, 2, 4),
(256, 8, 4, 8), # deepseek v3
(512, 16, 8, 16),
],
)
def test_moe_fused_gate_combined(seq_length, dtype, params):
num_experts, num_expert_group, topk_group, topk = params
torch.manual_seed(seq_length)
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
scores = tensor.clone()
bias = torch.rand(num_experts).to(dtype).cuda()
output, indices = moe_fused_gate(
tensor,
bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
topk=topk,
)
ref_output, ref_indices = biased_grouped_topk(
scores,
scores,
bias,
topk=topk,
renormalize=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
compiled=False,
)
idx_check = torch.allclose(
ref_indices.sort()[0].to(torch.int32),
indices.sort()[0].to(torch.int32),
rtol=1e-04,
atol=1e-05,
)
output_check = torch.allclose(
ref_output.sort()[0].to(torch.float32),
output.sort()[0].to(torch.float32),
rtol=1e-04,
atol=1e-05,
)
assert idx_check, (
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
f"params {params}"
)
assert output_check, (
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
f"params {params}"
)
if __name__ == "__main__":
pytest.main([__file__])