Commit 0b229519 authored by 王敏's avatar 王敏
Browse files

[feat]适配sgl moe_fused_gate kernel

parent 1150b65c
...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) ...@@ -621,7 +621,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
......
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include "../cuda_compat.h"
// #include <cutlass/array.h>
// #include <cutlass/cutlass.h>
// #include <cutlass/numeric_types.h>
#include <stdio.h>
#include <torch/all.h>
#include <cfloat>
#include <type_traits>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
/// Aligned array type
template <
typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
T data[N];
public:
__device__ T& operator[](int index) {
return data[index];
}
__device__ const T& operator[](int index) const {
return data[index];
}
};
// template <typename T, int N>
// using AlignedArray = cutlass::AlignedArray<T, N>;
// using bfloat16_t = cutlass::bfloat16_t;
// using float16_t = cutlass::half_t;
using float32_t = float;
// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" matches these operands: built-in operator
// "arithmetic > arithmetic" function "operator>(const __half &, const __half &)"
template <typename T>
__device__ inline bool cmp_gt(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
// at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
return static_cast<float>(a) > static_cast<float>(b);
} else {
// For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
return a > b;
}
}
template <typename T>
__device__ inline bool cmp_eq(const T& a, const T& b) {
if constexpr (std::is_same<T, at::Half>::value) {
return static_cast<float>(a) == static_cast<float>(b);
} else {
return a == b;
}
}
// Fixed constants common to both dynamic and static template versions:
//static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_CTA = 6;
static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
// Create an alias for Array using AlignedArray
template <typename T, int N>
using Array = AlignedArray<T, N>;
// QQ: NOTE expression must have a constant value, this has to be > params.VPT
template <typename T>
using AccessType = AlignedArray<T, MAX_VPT>;
template <typename T, typename Params>
__device__ void moe_fused_gate_impl(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
if (thread_row >= num_rows) {
return;
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
auto* bias_ptr = reinterpret_cast<T*>(bias);
auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;
int thread_group_idx = tidx % params.THREADS_PER_ROW;
int first_elt_read_by_thread = thread_group_idx * params.VPT;
// Create local arrays for the row chunk and bias chunk and then reinterpret the address of row_chunk as a pointer to
// AccessType.
T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> row_chunk;
// T row_chunk[params.VPT];
AccessType<T> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(thread_read_ptr);
T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
Array<T, MAX_VPT> bias_chunk;
// T bias_chunk[params.VPT];
AccessType<T> const* vec_bias_thread_read_ptr = reinterpret_cast<AccessType<T> const*>(bias_thread_read_ptr);
//AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
//AccessType<T>* bias_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&bias_chunk);
// QQ NOTE: doing the follow will be slower than loop assign and more importantly
// have misaligned address issue when params.VPT < 8 and mismatch with MAX_VPT
// AccessType<T>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T>*>(&row_chunk);
// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = vec_thread_read_ptr[0][ii];
bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
}
/*row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
bias_chunk_vec_ptr[0] = vec_bias_thread_read_ptr[0];*/
__syncthreads();
////////////////////// Sigmoid //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
}
__syncthreads();
////////////////////// Add Bias //////////////////////
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
}
////////////////////// Exclude Groups //////////////////////
#pragma unroll
for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
++k_idx) { // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
int expert = first_elt_read_by_thread;
// local argmax
T max_val = static_cast<T>(-FLT_MAX);
T max_val_second = static_cast<T>(-FLT_MAX);
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val_second = max_val;
max_val = val;
} else if (cmp_gt(val, max_val_second)) {
max_val_second = val;
}
}
// QQ NOTE: currently fixed to pick top2 sigmoid weight value in each expert group and sum them as the group weight
// to select expert groups
T max_sum = max_val + max_val_second;
// argmin reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max_sum =
static_cast<T>(VLLM_SHFL_XOR_SYNC_WIDTH(static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, params.THREADS_PER_ROW);
// higher indices win
if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
max_sum = other_max_sum;
expert = other_expert;
}
}
// clear the max value in the thread
if (k_idx < params.THREADS_PER_ROW - topk_group) {
int const thread_to_clear_in_group = expert / params.VPT;
if (thread_group_idx == thread_to_clear_in_group) {
#pragma unroll
for (int ii = 0; ii < params.VPT; ++ii) {
bias_chunk[ii] = static_cast<T>(FLT_MAX);
}
}
}
}
__syncthreads();
////////////////////// Topk //////////////////////
float output_sum = 0.0f;
for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
// local argmax
T max_val = bias_chunk[0];
int expert = first_elt_read_by_thread;
if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
#pragma unroll
for (int ii = 1; ii < params.VPT; ++ii) {
T val = bias_chunk[ii];
if (cmp_gt(val, max_val)) {
max_val = val;
expert = first_elt_read_by_thread + ii;
}
}
} else {
max_val = static_cast<T>(-FLT_MAX);
}
// argmax reduce
#pragma unroll
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
T other_max =
static_cast<T>(VLLM_SHFL_XOR_SYNC_WIDTH(static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, params.THREADS_PER_ROW);
// lower indices to win
if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
max_val = other_max;
expert = other_expert;
}
}
int thread_to_clear_in_group = expert / params.VPT;
int64_t idx = topk * thread_row + k_idx;
if (thread_group_idx == thread_to_clear_in_group) {
int expert_to_clear_in_thread = expert % params.VPT;
// clear the max value in the thread
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
// store output
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
indices_ptr[idx] = static_cast<int32_t>(expert);
}
// accumulate sum for all elements
if (thread_group_idx == 0) {
output_sum += output_ptr[idx];
}
__syncthreads();
}
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
// Use round-robin to select expert
int64_t expert_offset = thread_row % n_share_experts_fusion;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
}
__syncthreads();
////////////////////// Rescale Output //////////////////////
if (thread_group_idx == 0) {
#pragma unroll
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = output_ptr[idx] / output_sum;
}
}
}
//------------------------------------------------------------------------------
// Templated Kernel Version (using compile-time constants)
//------------------------------------------------------------------------------
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
struct KernelParams {
static constexpr int VPT = VPT_;
static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
};
template <
typename T,
int VPT,
int NUM_EXPERTS,
int THREADS_PER_ROW,
int ROWS_PER_WARP,
int ROWS_PER_CTA,
int WARPS_PER_CTA>
__global__ void moe_fused_gate_kernel(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
bias.data_ptr(), \
output.data_ptr<float>(), \
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk, \
n_share_experts_fusion, \
routed_scaling_factor); \
dispatched = true; \
} while (0)
//------------------------------------------------------------------------------
// Dynamic Kernel Version (parameters computed at runtime)
//------------------------------------------------------------------------------
struct KernelParamsDynamic {
int VPT;
int NUM_EXPERTS;
int THREADS_PER_ROW;
int ROWS_PER_WARP;
int ROWS_PER_CTA;
int WARPS_PER_CTA;
};
template <typename T>
__global__ void moe_fused_gate_kernel_dynamic(
void* input,
void* bias,
float* output_ptr,
int32_t* indices_ptr,
int64_t num_rows,
int64_t num_experts,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
params.THREADS_PER_ROW = num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, this is 8
params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}
//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto output = torch::empty({num_rows, topk}, options);
auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
// Check 2: Ensure that num_experts is divisible by num_expert_group. (this also means num_expert_group is power of 2)
TORCH_CHECK(
num_experts % num_expert_group == 0,
"num_experts must be divisible by num_expert_group, but got ",
num_experts,
" / ",
num_expert_group);
int computed_vpt = num_experts / num_expert_group;
// Check 3: Ensure that num_experts/num_expert_group does not exceed MAX_VPT=32. Maximum VPT indicate max value per
// threads we can process.
TORCH_CHECK(
computed_vpt <= MAX_VPT,
"Per group experts: num_experts / num_expert_group = (",
computed_vpt,
") exceeds the maximum supported (",
MAX_VPT,
")");
// Dispatch to templated kernel for known compile-time configurations.
// We currently only support for:
// Case 1: 256 experts, with 8 or 16 groups.
// Case 2: 128 experts, with 4 or 8 groups.
// Case 3: other cases, require 8 <= num_experts / num_expert_group <= 32
bool dispatched = false;
switch (num_experts) {
case 256:
if (num_expert_group == 8)
// This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 256, 8);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(half, 256, 8);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float, 256, 8);
} else if (num_expert_group == 16)
// Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 256, 16);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(half, 256, 16);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float, 256, 16);
}
break;
case 128:
if (num_expert_group == 4)
// VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 128, 4);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(half, 128, 4);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float, 128, 4);
} else if (num_expert_group == 8)
// VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
if (input.scalar_type() == at::kBFloat16) {
LAUNCH_MOE_GATE_CONFIG(__nv_bfloat16, 128, 8);
} else if (input.scalar_type() == at::kHalf) {
LAUNCH_MOE_GATE_CONFIG(half, 128, 8);
} else if (input.scalar_type() == at::kFloat) {
LAUNCH_MOE_GATE_CONFIG(float, 128, 8);
}
break;
default:
break;
}
if (!dispatched) {
// Fallback to the dynamic kernel if none of the supported combinations match.
// currently only support num_experts / num_expert_group <= 32 for dynamic kernels
if (input.scalar_type() == at::kBFloat16) {
moe_fused_gate_kernel_dynamic<__nv_bfloat16><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<half><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
bias.data_ptr(),
output.data_ptr<float>(),
indices.data_ptr<int32_t>(),
num_rows,
num_experts,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
}
return {output, indices};
}
...@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, ...@@ -28,4 +28,13 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k, torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit); int64_t BLOCK_SIZE_K, int64_t bit);
#endif #endif
\ No newline at end of file
std::vector<torch::Tensor> moe_fused_gate(
torch::Tensor& input,
torch::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);
\ No newline at end of file
...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -31,6 +31,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......
...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache( ...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, # torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale) # seq_lens, page_table, scale)
# return out # return out
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion=0,
routed_scaling_factor=0,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return torch.ops._moe_C.moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
import json import json
import os import os
import math
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
...@@ -1182,6 +1183,10 @@ def fused_topk( ...@@ -1182,6 +1183,10 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
def is_power_of_two(n):
return n > 0 and math.log2(n).is_integer()
# This is used by the Deepseek-V2 and Deepseek-V3 model # This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk( def grouped_topk(
......
...@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -23,6 +23,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import fused_experts from .fused_moe import fused_experts
...@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -222,7 +223,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor if hasattr(self, "routed_scaling_factor") else None)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
...@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module): ...@@ -436,6 +438,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
): ):
super().__init__() super().__init__()
...@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module): ...@@ -505,6 +508,7 @@ class FusedMoE(torch.nn.Module):
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.activation = activation self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
...@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module): ...@@ -554,6 +558,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
setattr(self.quant_method, "routed_scaling_factor", self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module): ...@@ -839,23 +844,39 @@ class FusedMoE(torch.nn.Module):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None): e_score_correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk, is_power_of_two)
# DeekSeekv2 uses grouped_top_k # DeekSeekv2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk( if e_score_correction_bias is not None \
hidden_states=hidden_states, and router_logits.shape[1] // num_expert_group <= 32 \
gating_output=router_logits, and is_power_of_two(e_score_correction_bias.shape[0]):
topk=top_k,
renormalize=renormalize, # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
num_expert_group=num_expert_group, topk_weights, topk_ids = ops.moe_fused_gate(
topk_group=topk_group, router_logits,
scoring_func=scoring_func, e_score_correction_bias,
e_score_correction_bias=e_score_correction_bias) num_expert_group,
topk_group,
top_k,
routed_scaling_factor=routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module): ...@@ -926,7 +947,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,) e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment