Unverified Commit c0652d90 authored by Lianmin Zheng, committed by GitHub

Clean up sgl kernel (#12413)


Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
parent 2e48584b
......@@ -37,7 +37,11 @@ from pydantic import (
model_validator,
)
from typing_extensions import Literal
from xgrammar import StructuralTag
try:
from xgrammar import StructuralTag
except ImportError:
StructuralTag = Any
from sglang.utils import convert_json_schema_to_str
......
......@@ -42,7 +42,7 @@ endif()
find_package(Torch REQUIRED)
clear_cuda_arches(CMAKE_FLAG)
# Third Party
# Third Party repos
# cutlass
FetchContent_Declare(
repo-cutlass
......@@ -271,6 +271,8 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
)
endif()
# All source files
# NOTE: Please sort the filenames alphabetically
set(SOURCES
"csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
......@@ -279,16 +281,15 @@ set(SOURCES
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/common_extension.cc"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/copy.cu"
"csrc/elementwise/concat_mla.cu"
"csrc/elementwise/copy.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/elementwise/topk.cu"
"csrc/common_extension.cc"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/expert_specialization/es_fp8_blockwise.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
......@@ -314,10 +315,11 @@ set(SOURCES
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/mamba/causal_conv1d.cu"
"csrc/memory/store.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
......@@ -332,16 +334,12 @@ set(SOURCES
"csrc/moe/fp8_blockwise_moe_kernel.cu"
"csrc/moe/prepare_moe_input.cu"
"csrc/memory/store.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/ngram_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/expert_specialization/es_fp8_blockwise.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
......@@ -356,17 +354,7 @@ set(SOURCES
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)
# =========================== Common SM90 Build ============================= #
# Build SM90 library with fast math optimization (same namespace, different directory)
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_definitions(common_ops_sm90_build PRIVATE
USE_FAST_MATH=1
)
target_compile_options(common_ops_sm90_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
)
target_include_directories(common_ops_sm90_build PRIVATE
set(INCLUDES
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
......@@ -376,6 +364,15 @@ target_include_directories(common_ops_sm90_build PRIVATE
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
# =========================== Common SM90 Build ============================= #
# Build SM90 library with fast math optimization (same namespace, different directory)
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_options(common_ops_sm90_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
)
target_include_directories(common_ops_sm90_build PRIVATE ${INCLUDES})
# Set output name and separate build directory to avoid conflicts
set_target_properties(common_ops_sm90_build PROPERTIES
OUTPUT_NAME "common_ops"
......@@ -386,22 +383,10 @@ set_target_properties(common_ops_sm90_build PROPERTIES
# Build SM100+ library with precise math (same namespace, different directory)
Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_definitions(common_ops_sm100_build PRIVATE
USE_FAST_MATH=0
)
target_compile_options(common_ops_sm100_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
)
target_include_directories(common_ops_sm100_build PRIVATE
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
target_include_directories(common_ops_sm100_build PRIVATE ${INCLUDES})
# Set output name and separate build directory to avoid conflicts
set_target_properties(common_ops_sm100_build PROPERTIES
OUTPUT_NAME "common_ops"
......@@ -432,6 +417,7 @@ add_subdirectory(
${repo-mscclpp_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
)
target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
......@@ -453,7 +439,7 @@ target_compile_definitions(common_ops_sm100_build PRIVATE
install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90)
install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100)
# ============================ Optional Install ============================= #
# ============================ Optional Install: FA3 ============================= #
# set flash-attention sources file
# Now FA3 supports sm80/sm86/sm90
if (SGL_KERNEL_ENABLE_FA3)
......@@ -553,10 +539,10 @@ target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERN
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
# ============================ Extra Install ============================= #
# ============================ Extra Install: FlashMLA ============================= #
include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake)
# ============================ DeepGEMM (JIT) ============================= #
# ============================ Extra Install: DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
set(DEEPGEMM_SOURCES
......@@ -601,13 +587,13 @@ install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
DESTINATION "deep_gemm/include/cutlass")
# triton_kernels
# ============================ Extra Install: triton kernels ============================= #
install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
DESTINATION "triton_kernels"
PATTERN ".git*" EXCLUDE
PATTERN "__pycache__" EXCLUDE)
# flash attention 4
# ============================ Extra Install: FA4 ============================= #
# TODO: find a better install condition.
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
# flash_attn/cute
......
......@@ -13,6 +13,8 @@ set(FLASHMLA_CUDA_FLAGS
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
)
# The FlashMLA kernels only work on hopper and require CUDA 12.4 or later.
......
......@@ -118,37 +118,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"topk_indices_offset) -> ()");
m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
/*
* From gguf quantization
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
/*
* From csrc/gemm
*/
......@@ -256,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
"moe_softcapping, Tensor? correction_bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
......@@ -289,22 +260,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
......@@ -324,6 +279,22 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/speculative
*/
......@@ -521,6 +492,38 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
/*
* From csrc/quantization/gguf
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size(int type) -> int");
m.impl("ggml_moe_get_block_size", torch::kCUDA, &ggml_moe_get_block_size);
/*
* From csrc/mamba
*/
......@@ -556,7 +559,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
/*
* From hadamard-transform
* From fast-hadamard-transform
*/
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
......
......@@ -70,7 +70,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
// quick allreduce
#ifdef USE_ROCM
m.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
......@@ -86,7 +85,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
// Max input size in bytes
m.def("qr_max_size", &qr_max_size);
#endif
/*
* From csrc/moe
......@@ -97,7 +95,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
"moe_softcapping, Tensor? correction_bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
......@@ -116,12 +116,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
/*
* From XGrammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From csrc/kvcacheio
*/
......@@ -171,6 +165,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
"Tensor dst_indices, int page_size) ->() ");
m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
/*
* From csrc/grammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
}
REGISTER_EXTENSION(common_ops)
......@@ -73,8 +73,13 @@ __device__ float convert_to_float(T x) {
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
__launch_bounds__(TPB) __global__ void moeSoftmax(
const T* input,
const bool* finished,
float* output,
const int num_cols,
const float moe_softcapping,
const float* correction_bias) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
......@@ -90,9 +95,23 @@ __launch_bounds__(TPB) __global__
return;
}
// First pass: Apply transformation, find max, and write transformed values to output
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(convert_to_float<T>(input[idx]), threadData);
float val = convert_to_float<T>(input[idx]);
// Apply tanh softcapping if enabled
if (moe_softcapping != 0.0f) {
val = tanhf(val / moe_softcapping) * moe_softcapping;
}
// Apply correction bias if provided
if (correction_bias != nullptr) {
val = val + correction_bias[ii];
}
output[idx] = val; // Store transformed value
threadData = max(val, threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
......@@ -102,11 +121,11 @@ __launch_bounds__(TPB) __global__
}
__syncthreads();
// Second pass: Compute sum using transformed values from output
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
threadData += exp((output[idx] - float_max));
}
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
......@@ -116,10 +135,11 @@ __launch_bounds__(TPB) __global__
}
__syncthreads();
// Third pass: Compute final softmax using transformed values from output
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
const float softmax_val = exp((output[idx] - float_max)) * normalizing_factor;
output[idx] = softmax_val;
}
}
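For intuition, the transformation the updated moeSoftmax path applies per row (tanh softcap, then an optional per-expert bias, then a numerically stable softmax) can be sketched in plain PyTorch. This is only a reference sketch: the function name, shapes, and dtypes are illustrative and not part of the kernel API.

from typing import Optional

import torch


def reference_moe_softmax(
    gating_output: torch.Tensor,  # [num_tokens, num_experts] logits
    moe_softcapping: float = 0.0,
    correction_bias: Optional[torch.Tensor] = None,  # [num_experts], float32
) -> torch.Tensor:
    vals = gating_output.float()
    if moe_softcapping != 0.0:
        # Tanh softcapping, mirroring the kernel's "enabled when != 0" check
        vals = torch.tanh(vals / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        # Per-expert bias, broadcast across all tokens
        vals = vals + correction_bias.float()
    # Numerically stable softmax over the expert dimension
    vals = vals - vals.max(dim=-1, keepdim=True).values
    probs = vals.exp()
    return probs / probs.sum(dim=-1, keepdim=True)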
......@@ -216,7 +236,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
const int k,
const int start_expert,
const int end_expert,
const bool renormalize) {
const bool renormalize,
const float moe_softcapping,
const float* correction_bias) {
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
......@@ -283,16 +305,48 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
// Note(Byron): interleaved loads to achieve better memory coalescing
// | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ...
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
float row_chunk[VPT];
#pragma unroll
// Note(Byron): upcast logits to float32
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
}
// Apply tanh softcapping and correction bias
if (moe_softcapping != 0.0f || correction_bias != nullptr) {
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
// Apply tanh softcapping if enabled
if (moe_softcapping != 0.0f) {
val = tanhf(val / moe_softcapping) * moe_softcapping;
}
// Apply correction bias if provided
if (correction_bias != nullptr) {
/*
LDG is interleaved
|thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG|
|--------- group0 --------| |----------group1 --------|
^ local2
*/
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id;
val = val + correction_bias[expert_idx];
}
row_chunk[ii] = val;
}
}
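The expert_idx arithmetic above undoes the interleaved vectorized loads so that correction_bias is indexed by the true expert column. A self-contained Python check, using made-up stand-ins for the kernel's compile-time constants, confirms that the mapping visits every column of a row exactly once across the cooperating threads:

# Illustrative constants; the real kernel derives these at compile time.
THREADS_PER_ROW = 4   # threads cooperating on one row
ELTS_PER_LDG = 2      # elements per vectorized load
LDG_PER_THREAD = 4    # vectorized loads issued per thread
VPT = ELTS_PER_LDG * LDG_PER_THREAD   # values held per thread
NUM_EXPERTS = THREADS_PER_ROW * VPT   # columns in one row

seen = set()
for lane in range(THREADS_PER_ROW):
    first_elt_read_by_thread = lane * ELTS_PER_LDG
    for ii in range(VPT):
        group_id, local_id = divmod(ii, ELTS_PER_LDG)
        expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id
        seen.add(expert_idx)

# Every expert column is covered exactly once across the thread group.
assert seen == set(range(NUM_EXPERTS))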
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float thread_max = row_chunk[0];
......@@ -301,9 +355,15 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
thread_max = max(thread_max, row_chunk[ii]);
}
/*********************************/
/********* Softmax Begin *********/
/*********************************/
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
// lane id: 0-31 within a warp
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
// butterfly reduce with (lane id ^ mask)
thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
}
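The XOR-mask loop above is a butterfly reduction: after log2(THREADS_PER_ROW) exchange rounds, every lane holds the row-group maximum. A host-side Python sketch of the same pattern (the lane values and group width are illustrative):

def butterfly_max(values):
    width = len(values)  # must be a power of two, like THREADS_PER_ROW
    lanes = list(values)
    mask = width // 2
    while mask > 0:
        # Each lane exchanges with the lane whose id differs in the `mask` bit.
        lanes = [max(lanes[i], lanes[i ^ mask]) for i in range(width)]
        mask //= 2
    return lanes

print(butterfly_max([3.0, -1.0, 7.5, 0.25]))  # -> [7.5, 7.5, 7.5, 7.5]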
......@@ -333,6 +393,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
/*******************************/
/********* Softmax End *********/
/*******************************/
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
......@@ -438,6 +501,8 @@ void topkGatingSoftmaxLauncherHelper(
const int start_expert,
const int end_expert,
const bool renormalize,
const float moe_softcapping,
const float* correction_bias,
cudaStream_t stream) {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
......@@ -450,12 +515,33 @@ void topkGatingSoftmaxLauncherHelper(
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
input,
finished,
output,
num_rows,
indices,
k,
start_expert,
end_expert,
renormalize,
moe_softcapping,
correction_bias);
}
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
gating_output, \
nullptr, \
topk_weights, \
topk_indices, \
num_tokens, \
topk, \
0, \
num_experts, \
renormalize, \
moe_softcapping, \
correction_bias, \
stream);
template <typename T>
void topkGatingSoftmaxKernelLauncher(
......@@ -467,6 +553,8 @@ void topkGatingSoftmaxKernelLauncher(
const int num_experts,
const int topk,
const bool renormalize,
const float moe_softcapping,
const float* correction_bias,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
......@@ -502,7 +590,8 @@ void topkGatingSoftmaxKernelLauncher(
softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2.");
static constexpr int TPB = 256;
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts, moe_softcapping, correction_bias);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
}
......@@ -510,11 +599,12 @@ void topkGatingSoftmaxKernelLauncher(
}
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& gating_output,
const bool renormalize) // [num_tokens, num_experts]
{
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts]
const bool renormalize,
const double moe_softcapping,
const c10::optional<torch::Tensor>& correction_bias) {
// Check data type
TORCH_CHECK(
gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
......@@ -552,6 +642,23 @@ void topk_softmax(
torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
const at::ScalarType dtype = gating_output.scalar_type();
// Validate correction_bias if provided - must always be float32
const float* bias_ptr = nullptr;
if (correction_bias.has_value()) {
const torch::Tensor& bias_tensor = correction_bias.value();
TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]");
TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts");
TORCH_CHECK(
bias_tensor.scalar_type() == at::ScalarType::Float,
"correction_bias must be float32, got ",
bias_tensor.scalar_type());
bias_ptr = bias_tensor.data_ptr<float>();
}
// Cast moe_softcapping from double to float for CUDA kernels
const float moe_softcapping_f = static_cast<float>(moe_softcapping);
if (dtype == at::ScalarType::Float) {
topkGatingSoftmaxKernelLauncher<float>(
gating_output.data_ptr<float>(),
......@@ -562,6 +669,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else if (dtype == at::ScalarType::Half) {
topkGatingSoftmaxKernelLauncher<__half>(
......@@ -573,6 +682,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else if (dtype == at::ScalarType::BFloat16) {
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
......@@ -584,6 +695,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
......
......@@ -191,32 +191,6 @@ void fast_topk_transform_ragged_interface(
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From gguf quantization
*/
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/gemm
*/
......@@ -333,7 +307,12 @@ void moe_align_block_size(
bool pad_sorted_token_ids);
void topk_softmax(
torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& gating_output,
bool renormalize,
double moe_softcapping,
const c10::optional<torch::Tensor>& correction_bias);
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);
......@@ -417,6 +396,7 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor const& input_global_scale,
torch::Tensor const& mask,
bool use_silu_and_mul);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
......@@ -445,7 +425,9 @@ void cutlass_w4a8_moe_mm(
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From csrc/moe/marlin_moe_wna16
*/
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
......@@ -680,6 +662,11 @@ void transfer_kv_all_layer_direct_lf_pf(
const at::Tensor& dst_indices,
int64_t page_size);
/*
* From csrc/memory
*/
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
/*
* From FlashInfer
*/
......@@ -798,12 +785,12 @@ void convert_vertical_slash_indexes_mergehead(
bool causal);
/*
* From XGrammar
* From csrc/grammar
*/
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
/*
* From QServe
* From csrc/gemm (QServe)
*/
void qserve_w4a8_per_chn_gemm(
const torch::Tensor& _in_feats,
......@@ -824,14 +811,35 @@ void qserve_w4a8_per_group_gemm(
torch::Tensor& _out_feats);
/*
* From csrc/spatial
* From csrc/quantization/gguf
*/
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/memory
* From csrc/spatial
*/
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
/*
* From csrc/mamba
......@@ -883,7 +891,7 @@ torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
/*
* From csrc/fastertransformer
* From flashmla
*/
std::vector<at::Tensor> get_mla_decoding_metadata(
at::Tensor& seqlens_k,
......
import ctypes
import logging
import os
import shutil
from pathlib import Path
from typing import List
import torch
logger = logging.getLogger(__name__)
def _get_compute_capability():
"""Get the compute capability of the current GPU."""
if not torch.cuda.is_available():
return None
# Get the current device
device = torch.cuda.current_device()
properties = torch.cuda.get_device_properties(device)
# Return as integer (major * 10 + minor)
return properties.major * 10 + properties.minor
def _filter_compiled_extensions(file_list):
"""Filter and prioritize compiled extensions over Python source files."""
compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes
compiled_files = []
other_files = []
for file_path in file_list:
path = Path(file_path)
# Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
if any(
str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions
):
compiled_files.append(file_path)
else:
other_files.append(file_path)
# Return compiled files first, then others
return compiled_files + other_files
def _load_architecture_specific_ops():
"""Load the appropriate common_ops library based on GPU architecture."""
import importlib.util
import sys
from pathlib import Path
compute_capability = _get_compute_capability()
logger.debug(
f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
)
# Get the directory where sgl_kernel is installed
sgl_kernel_dir = Path(__file__).parent
logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
# Determine which version to load based on GPU architecture
if compute_capability == 90:
ops_subdir = "sm90"
variant_name = "SM90 (Hopper/H100 with fast math optimization)"
elif compute_capability is not None:
ops_subdir = "sm100"
variant_name = f"SM{compute_capability} (precise math for compatibility)"
else:
ops_subdir = "sm100"
variant_name = "CPU/No GPU detected (using precise math)"
# Look for the compiled module with any valid extension
import glob
ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
raw_matching_files = glob.glob(ops_pattern)
matching_files = _filter_compiled_extensions(raw_matching_files)
logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
previous_import_errors: List[Exception] = []
# Try to load from the architecture-specific directory
if matching_files:
ops_path = Path(matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
try:
# Load the module from specific path using importlib
spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
if spec is None:
raise ImportError(f"Could not create module spec for {ops_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {ops_path}")
logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
)
# Continue to fallback
else:
logger.debug(
f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
)
# Try alternative directory (in case installation structure differs)
alt_pattern = str(sgl_kernel_dir / "common_ops.*")
raw_alt_files = glob.glob(alt_pattern)
alt_matching_files = _filter_compiled_extensions(raw_alt_files)
logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
if alt_matching_files:
alt_path = Path(alt_matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
try:
spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
if spec is None:
raise ImportError(f"Could not create module spec for {alt_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {alt_path}")
logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
)
else:
logger.debug(
f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
)
# Final attempt: try standard Python import (for backward compatibility)
logger.debug(
f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
)
try:
import common_ops
logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except ImportError as e:
previous_import_errors.append(e)
logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
attempt_error_msg = "\n".join(
f"- {type(err).__name__}: {err}" for err in previous_import_errors
)
# All attempts failed
error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
Error details from previous import attempts:
{attempt_error_msg}
"""
logger.debug(error_msg)
raise ImportError(error_msg)
from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library
# Initialize the ops library based on current GPU
logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
common_ops = _load_architecture_specific_ops()
logger.debug("[sgl_kernel] ✓ Operator library initialization complete")
# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home
# Preload the CUDA library to avoid the issue of libcudart.so.12 not found
if torch.version.cuda is not None:
cuda_home = Path(_find_cuda_home())
if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")
_preload_cuda_library()
cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
......
import ctypes
import glob
import importlib.util
import logging
import os
import shutil
from pathlib import Path
from typing import List
import torch
logger = logging.getLogger(__name__)
def _get_compute_capability():
"""Get the compute capability of the current GPU."""
if not torch.cuda.is_available():
return None
# Get the current device
device = torch.cuda.current_device()
properties = torch.cuda.get_device_properties(device)
# Return as integer (major * 10 + minor)
return properties.major * 10 + properties.minor
def _filter_compiled_extensions(file_list):
"""Filter and prioritize compiled extensions over Python source files."""
compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes
compiled_files = []
other_files = []
for file_path in file_list:
path = Path(file_path)
# Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
if any(
str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions
):
compiled_files.append(file_path)
else:
other_files.append(file_path)
# Return compiled files first, then others
return compiled_files + other_files
def _load_architecture_specific_ops():
"""Load the appropriate common_ops library based on GPU architecture."""
compute_capability = _get_compute_capability()
logger.debug(
f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
)
# Get the directory where sgl_kernel is installed
sgl_kernel_dir = Path(__file__).parent
logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
# Determine which version to load based on GPU architecture
if compute_capability == 90:
ops_subdir = "sm90"
variant_name = "SM90 (Hopper/H100 with fast math optimization)"
elif compute_capability is not None:
ops_subdir = "sm100"
variant_name = f"SM{compute_capability} (precise math for compatibility)"
else:
ops_subdir = "sm100"
variant_name = "CPU/No GPU detected (using precise math)"
# Look for the compiled module with any valid extension
ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
raw_matching_files = glob.glob(ops_pattern)
matching_files = _filter_compiled_extensions(raw_matching_files)
logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
previous_import_errors: List[Exception] = []
# Try to load from the architecture-specific directory
if matching_files:
ops_path = Path(matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
try:
# Load the module from specific path using importlib
spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
if spec is None:
raise ImportError(f"Could not create module spec for {ops_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {ops_path}")
logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
)
# Continue to fallback
else:
logger.debug(
f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
)
# Try alternative directory (in case installation structure differs)
alt_pattern = str(sgl_kernel_dir / "common_ops.*")
raw_alt_files = glob.glob(alt_pattern)
alt_matching_files = _filter_compiled_extensions(raw_alt_files)
logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
if alt_matching_files:
alt_path = Path(alt_matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
try:
spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
if spec is None:
raise ImportError(f"Could not create module spec for {alt_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {alt_path}")
logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
)
else:
logger.debug(
f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
)
# Final attempt: try standard Python import (for backward compatibility)
logger.debug(
f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
)
try:
import common_ops
logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except ImportError as e:
previous_import_errors.append(e)
logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
attempt_error_msg = "\n".join(
f"- {type(err).__name__}: {err}" for err in previous_import_errors
)
# All attempts failed
error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
Error details from previous import attempts:
{attempt_error_msg}
"""
logger.debug(error_msg)
raise ImportError(error_msg)
# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home
def _preload_cuda_library():
cuda_home = Path(_find_cuda_home())
if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")
cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
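A quick illustration of the load-priority helper defined above, with hypothetical paths: compiled artifacts are returned ahead of any other glob matches, so a stray common_ops.py stub cannot shadow the real extension module.

from sgl_kernel.load_utils import _filter_compiled_extensions

candidates = [
    "/opt/sgl_kernel/sm90/common_ops.py",
    "/opt/sgl_kernel/sm90/common_ops.abi3.so",
]
print(_filter_compiled_extensions(candidates))
# ['/opt/sgl_kernel/sm90/common_ops.abi3.so', '/opt/sgl_kernel/sm90/common_ops.py']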
......@@ -28,11 +28,29 @@ def moe_align_block_size(
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
gating_output: float,
gating_output: torch.Tensor,
renormalize: bool = False,
moe_softcapping: float = 0.0,
correction_bias: Optional[torch.Tensor] = None,
) -> None:
"""
Compute top-k softmax for MoE routing.
Args:
topk_weights: Output tensor for top-k weights [num_tokens, topk]
topk_ids: Output tensor for top-k expert indices [num_tokens, topk]
gating_output: Gating logits [num_tokens, num_experts]
renormalize: Whether to renormalize the top-k weights
moe_softcapping: Tanh softcapping value (0.0 to disable)
correction_bias: Per-expert bias correction [num_experts], must be float32 if provided
"""
torch.ops.sgl_kernel.topk_softmax.default(
topk_weights, topk_ids, gating_output, renormalize
topk_weights,
topk_ids,
gating_output,
renormalize,
moe_softcapping,
correction_bias,
)
......
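A short usage sketch of the extended topk_softmax wrapper. The shapes, output dtypes, and softcapping value are illustrative, and the import path assumes the wrapper is re-exported from the sgl_kernel package; adjust to the module where it actually lives.

import torch
from sgl_kernel import topk_softmax  # assumed re-export of the wrapper above

num_tokens, num_experts, topk = 4, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float16)
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)  # dtype assumed
correction_bias = torch.zeros(num_experts, device="cuda", dtype=torch.float32)  # must be float32

topk_softmax(
    topk_weights,
    topk_ids,
    gating_output,
    renormalize=True,
    moe_softcapping=30.0,  # 0.0 disables tanh softcapping
    correction_bias=correction_bias,
)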