Unverified Commit c0652d90 authored by Lianmin Zheng, committed by GitHub

Clean up sgl kernel (#12413)


Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
parent 2e48584b
@@ -37,7 +37,11 @@ from pydantic import (
model_validator,
)
from typing_extensions import Literal
from xgrammar import StructuralTag
try:
from xgrammar import StructuralTag
except:
StructuralTag = Any
from sglang.utils import convert_json_schema_to_str
......
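Note: the hunk above turns xgrammar into an optional dependency, aliasing StructuralTag to Any when the package is missing. A minimal sketch of the pattern, assuming `Any` is imported from `typing` elsewhere in the module and using the narrower `ImportError` in place of the bare `except`:

from typing import Any

try:
    from xgrammar import StructuralTag  # real tag class when xgrammar is installed
except ImportError:
    # Fallback keeps the annotation importable without xgrammar.
    StructuralTag = Any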
@@ -42,7 +42,7 @@ endif()
find_package(Torch REQUIRED)
clear_cuda_arches(CMAKE_FLAG)
# Third Party
# Third Party repos
# cutlass
FetchContent_Declare(
repo-cutlass
@@ -271,6 +271,8 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
)
endif()
# All source files
# NOTE: Please sort the filenames alphabetically
set(SOURCES
"csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
@@ -279,16 +281,15 @@ set(SOURCES
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/common_extension.cc"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/copy.cu"
"csrc/elementwise/concat_mla.cu"
"csrc/elementwise/copy.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/elementwise/topk.cu"
"csrc/common_extension.cc"
"csrc/expert_specialization/es_fp8_blockwise.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/gemm/awq_kernel.cu" "csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu" "csrc/gemm/bmm_fp8.cu"
...@@ -314,10 +315,11 @@ set(SOURCES ...@@ -314,10 +315,11 @@ set(SOURCES
"csrc/gemm/marlin/gptq_marlin_repack.cu" "csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu" "csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu" "csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/mamba/causal_conv1d.cu" "csrc/mamba/causal_conv1d.cu"
"csrc/memory/store.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
...@@ -332,16 +334,12 @@ set(SOURCES ...@@ -332,16 +334,12 @@ set(SOURCES
"csrc/moe/fp8_blockwise_moe_kernel.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu"
"csrc/moe/prepare_moe_input.cu" "csrc/moe/prepare_moe_input.cu"
"csrc/memory/store.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu" "csrc/speculative/eagle_utils.cu"
"csrc/speculative/ngram_utils.cu" "csrc/speculative/ngram_utils.cu"
"csrc/speculative/packbit.cu" "csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu" "csrc/speculative/speculative_sampling.cu"
"csrc/expert_specialization/es_fp8_blockwise.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
...@@ -356,17 +354,7 @@ set(SOURCES ...@@ -356,17 +354,7 @@ set(SOURCES
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
) )
# =========================== Common SM90 Build ============================= # set(INCLUDES
# Build SM90 library with fast math optimization (same namespace, different directory)
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_definitions(common_ops_sm90_build PRIVATE
USE_FAST_MATH=1
)
target_compile_options(common_ops_sm90_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
)
target_include_directories(common_ops_sm90_build PRIVATE
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
@@ -376,6 +364,15 @@ target_include_directories(common_ops_sm90_build PRIVATE
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
# =========================== Common SM90 Build ============================= #
# Build SM90 library with fast math optimization (same namespace, different directory)
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_options(common_ops_sm90_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
)
target_include_directories(common_ops_sm90_build PRIVATE ${INCLUDES})
# Set output name and separate build directory to avoid conflicts
set_target_properties(common_ops_sm90_build PROPERTIES
OUTPUT_NAME "common_ops"
@@ -386,22 +383,10 @@ set_target_properties(common_ops_sm90_build PROPERTIES
# Build SM100+ library with precise math (same namespace, different directory)
Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_definitions(common_ops_sm100_build PRIVATE
USE_FAST_MATH=0
)
target_compile_options(common_ops_sm100_build PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
)
target_include_directories(common_ops_sm100_build PRIVATE
target_include_directories(common_ops_sm100_build PRIVATE ${INCLUDES})
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
# Set output name and separate build directory to avoid conflicts
set_target_properties(common_ops_sm100_build PROPERTIES
OUTPUT_NAME "common_ops"
@@ -432,6 +417,7 @@ add_subdirectory(
${repo-mscclpp_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
)
target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
@@ -453,7 +439,7 @@ target_compile_definitions(common_ops_sm100_build PRIVATE
install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90)
install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100)
# ============================ Optional Install ============================= #
# ============================ Optional Install: FA3 ============================= #
# set flash-attention sources file
# Now FA3 support sm80/sm86/sm90
if (SGL_KERNEL_ENABLE_FA3)
@@ -553,10 +539,10 @@ target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERN
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
# ============================ Extra Install ============================= #
# ============================ Extra Install: FLashMLA ============================= #
include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake)
# ============================ DeepGEMM (JIT) ============================= #
# ============================ Extra Install: DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
set(DEEPGEMM_SOURCES
@@ -601,13 +587,13 @@ install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
DESTINATION "deep_gemm/include/cutlass")
# triton_kernels
# ============================ Extra Install: triton kernels ============================= #
install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
DESTINATION "triton_kernels"
PATTERN ".git*" EXCLUDE
PATTERN "__pycache__" EXCLUDE)
# flash attention 4
# ============================ Extra Install: FA4 ============================= #
# TODO: find a better install condition.
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
# flash_attn/cute
......
@@ -13,6 +13,8 @@ set(FLASHMLA_CUDA_FLAGS
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
)
# The FlashMLA kernels only work on hopper and require CUDA 12.4 or later.
......
@@ -118,37 +118,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"topk_indices_offset) -> ()");
m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
/*
* From gguf quantiztion
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
/*
* From csrc/gemm
*/
@@ -256,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
"moe_softcapping, Tensor? correction_bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
@@ -289,22 +260,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
@@ -324,6 +279,22 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/speculative
*/
@@ -521,6 +492,38 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
/*
* From csrc/quantization/gguf
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size(int type) -> int");
m.impl("ggml_moe_get_block_size", torch::kCUDA, &ggml_moe_get_block_size);
/*
* From csrc/mamba
*/
@@ -556,7 +559,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
/*
* From hadamard-transform
* From fast-hadamard-transform
*/
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
......
@@ -70,7 +70,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
// quick allreduce
#ifdef USE_ROCM
m.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
@@ -86,7 +85,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
// Max input size in bytes
m.def("qr_max_size", &qr_max_size);
#endif
/*
* From csrc/moe
@@ -97,7 +95,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
"moe_softcapping, Tensor? correction_bias) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
@@ -116,12 +116,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
/*
* From XGrammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From csrc/kvcacheio
*/
@@ -171,6 +165,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
"transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
"Tensor dst_indices, int page_size) ->() ");
m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
/*
* From csrc/grammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
}
REGISTER_EXTENSION(common_ops)
@@ -73,8 +73,13 @@ __device__ float convert_to_float(T x) {
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
__launch_bounds__(TPB) __global__ void moeSoftmax(
const T* input,
const bool* finished,
float* output,
const int num_cols,
const float moe_softcapping,
const float* correction_bias) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -90,9 +95,23 @@ __launch_bounds__(TPB) __global__
return;
}
// First pass: Apply transformation, find max, and write transformed values to output
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData = max(convert_to_float<T>(input[idx]), threadData);
float val = convert_to_float<T>(input[idx]);
// Apply tanh softcapping if enabled
if (moe_softcapping != 0.0f) {
val = tanhf(val / moe_softcapping) * moe_softcapping;
}
// Apply correction bias if provided
if (correction_bias != nullptr) {
val = val + correction_bias[ii];
}
output[idx] = val; // Store transformed value
threadData = max(val, threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
@@ -102,11 +121,11 @@ __launch_bounds__(TPB) __global__
}
__syncthreads();
// Second pass: Compute sum using transformed values from output
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
threadData += exp((output[idx] - float_max));
}
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
@@ -116,10 +135,11 @@ __launch_bounds__(TPB) __global__
}
__syncthreads();
// Third pass: Compute final softmax using transformed values from output
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
const int idx = thread_row_offset + ii;
const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
const float softmax_val = exp((output[idx] - float_max)) * normalizing_factor;
output[idx] = softmax_val;
}
}
@@ -216,7 +236,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
const int k,
const int start_expert,
const int end_expert,
const bool renormalize) {
const bool renormalize,
const float moe_softcapping,
const float* correction_bias) {
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
@@ -283,16 +305,48 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
// Note(Byron): interleaved loads to achieve better memory coalescing
// | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ...
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
float row_chunk[VPT];
#pragma unroll
// Note(Byron): upcast logits to float32
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
}
// Apply tanh softcapping and correction bias
if (moe_softcapping != 0.0f || correction_bias != nullptr) {
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
float val = row_chunk[ii];
// Apply tanh softcapping if enabled
if (moe_softcapping != 0.0f) {
val = tanhf(val / moe_softcapping) * moe_softcapping;
}
// Apply correction bias if provided
if (correction_bias != nullptr) {
/*
LDG is interleaved
|thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG|
|--------- group0 --------| |----------group1 --------|
^ local2
*/
const int group_id = ii / ELTS_PER_LDG;
const int local_id = ii % ELTS_PER_LDG;
const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id;
val = val + correction_bias[expert_idx];
}
row_chunk[ii] = val;
}
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float thread_max = row_chunk[0];
@@ -301,9 +355,15 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
thread_max = max(thread_max, row_chunk[ii]);
}
/*********************************/
/********* Softmax Begin *********/
/*********************************/
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
// lane id: 0-31 within a warp
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
// butterfly reduce with (lane id ^ mask)
thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
}
@@ -333,6 +393,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
/*******************************/
/********* Softmax End *********/
/*******************************/
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
@@ -438,6 +501,8 @@ void topkGatingSoftmaxLauncherHelper(
const int start_expert,
const int end_expert,
const bool renormalize,
const float moe_softcapping,
const float* correction_bias,
cudaStream_t stream) {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -450,12 +515,33 @@ void topkGatingSoftmaxLauncherHelper(
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
input,
finished,
output,
num_rows,
indices,
k,
start_expert,
end_expert,
renormalize,
moe_softcapping,
correction_bias);
}
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
gating_output, \
nullptr, \
topk_weights, \
topk_indices, \
num_tokens, \
topk, \
0, \
num_experts, \
renormalize, \
moe_softcapping, \
correction_bias, \
stream);
template <typename T>
void topkGatingSoftmaxKernelLauncher(
@@ -467,6 +553,8 @@ void topkGatingSoftmaxKernelLauncher(
const int num_experts,
const int topk,
const bool renormalize,
const float moe_softcapping,
const float* correction_bias,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
@@ -502,7 +590,8 @@ void topkGatingSoftmaxKernelLauncher(
softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2.");
static constexpr int TPB = 256;
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts, moe_softcapping, correction_bias);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
}
@@ -512,9 +601,10 @@ void topkGatingSoftmaxKernelLauncher(
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& gating_output,
const bool renormalize) // [num_tokens, num_experts]
{
torch::Tensor& gating_output, // [num_tokens, num_experts]
const bool renormalize,
const double moe_softcapping,
const c10::optional<torch::Tensor>& correction_bias) {
// Check data type
TORCH_CHECK(
gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
@@ -552,6 +642,23 @@ void topk_softmax(
torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
const at::ScalarType dtype = gating_output.scalar_type();
// Validate correction_bias if provided - must always be float32
const float* bias_ptr = nullptr;
if (correction_bias.has_value()) {
const torch::Tensor& bias_tensor = correction_bias.value();
TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]");
TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts");
TORCH_CHECK(
bias_tensor.scalar_type() == at::ScalarType::Float,
"correction_bias must be float32, got ",
bias_tensor.scalar_type());
bias_ptr = bias_tensor.data_ptr<float>();
}
// Cast moe_softcapping from double to float for CUDA kernels
const float moe_softcapping_f = static_cast<float>(moe_softcapping);
if (dtype == at::ScalarType::Float) {
topkGatingSoftmaxKernelLauncher<float>(
gating_output.data_ptr<float>(),
@@ -562,6 +669,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else if (dtype == at::ScalarType::Half) {
topkGatingSoftmaxKernelLauncher<__half>(
@@ -573,6 +682,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else if (dtype == at::ScalarType::BFloat16) {
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
@@ -584,6 +695,8 @@ void topk_softmax(
num_experts,
topk,
renormalize,
moe_softcapping_f,
bias_ptr,
stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
......
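The kernel changes above apply optional tanh softcapping and an optional per-expert correction bias to the gating logits before the softmax and top-k selection. A pure-PyTorch sketch of that routing math, as read from the diff (the function name is illustrative and the renormalize handling is assumed to rescale the selected weights to sum to one); it is meant as a test oracle rather than a drop-in replacement:

from typing import Optional

import torch


def topk_softmax_reference(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    renormalize: bool = False,
    moe_softcapping: float = 0.0,
    correction_bias: Optional[torch.Tensor] = None,  # [num_experts], float32
):
    logits = gating_output.float()
    if moe_softcapping != 0.0:
        # Tanh softcapping, matching the first pass of moeSoftmax above.
        logits = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        logits = logits + correction_bias.float()
    probs = torch.softmax(logits, dim=-1)
    topk_weights, topk_indices = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_indices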
@@ -191,32 +191,6 @@ void fast_topk_transform_ragged_interface(
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From gguf quantization
*/
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/gemm
*/
@@ -333,7 +307,12 @@ void moe_align_block_size(
bool pad_sorted_token_ids);
void topk_softmax(
torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& gating_output,
bool renormalize,
double moe_softcapping,
const c10::optional<torch::Tensor>& correction_bias);
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);
@@ -417,6 +396,7 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor const& input_global_scale,
torch::Tensor const& mask,
bool use_silu_and_mul);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
@@ -445,7 +425,9 @@ void cutlass_w4a8_moe_mm(
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From csrc/moe/marlin_moe_wna16
*/
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
@@ -680,6 +662,11 @@ void transfer_kv_all_layer_direct_lf_pf(
const at::Tensor& dst_indices,
int64_t page_size);
/*
* From csrc/memory
*/
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
/*
* From FlashInfer
*/
@@ -798,12 +785,12 @@ void convert_vertical_slash_indexes_mergehead(
bool causal);
/*
* From XGrammar
* From csrc/grammar
*/
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
/*
* From QServe
* From csrc/gemm (QServe)
*/
void qserve_w4a8_per_chn_gemm(
const torch::Tensor& _in_feats,
@@ -824,14 +811,35 @@ void qserve_w4a8_per_group_gemm(
torch::Tensor& _out_feats);
/*
* From csrc/spatial
* From csrc/quantization/gguf
*/
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/memory
* From csrc/spatial
*/
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
/*
* From csrc/mamba
@@ -883,7 +891,7 @@ torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
/*
* From csrc/fastertransformer
* From flashmla
*/
std::vector<at::Tensor> get_mla_decoding_metadata(
at::Tensor& seqlens_k,
......
import ctypes
import logging
import os
import shutil
from pathlib import Path
from typing import List
import torch
from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library
logger = logging.getLogger(__name__)
def _get_compute_capability():
"""Get the compute capability of the current GPU."""
if not torch.cuda.is_available():
return None
# Get the current device
device = torch.cuda.current_device()
properties = torch.cuda.get_device_properties(device)
# Return as integer (major * 10 + minor)
return properties.major * 10 + properties.minor
def _filter_compiled_extensions(file_list):
"""Filter and prioritize compiled extensions over Python source files."""
compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes
compiled_files = []
other_files = []
for file_path in file_list:
path = Path(file_path)
# Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
if any(
str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions
):
compiled_files.append(file_path)
else:
other_files.append(file_path)
# Return compiled files first, then others
return compiled_files + other_files
def _load_architecture_specific_ops():
"""Load the appropriate common_ops library based on GPU architecture."""
import importlib.util
import sys
from pathlib import Path
compute_capability = _get_compute_capability()
logger.debug(
f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
)
# Get the directory where sgl_kernel is installed
sgl_kernel_dir = Path(__file__).parent
logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
# Determine which version to load based on GPU architecture
if compute_capability == 90:
ops_subdir = "sm90"
variant_name = "SM90 (Hopper/H100 with fast math optimization)"
elif compute_capability is not None:
ops_subdir = "sm100"
variant_name = f"SM{compute_capability} (precise math for compatibility)"
else:
ops_subdir = "sm100"
variant_name = "CPU/No GPU detected (using precise math)"
# Look for the compiled module with any valid extension
import glob
ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
raw_matching_files = glob.glob(ops_pattern)
matching_files = _filter_compiled_extensions(raw_matching_files)
logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
previous_import_errors: List[Exception] = []
# Try to load from the architecture-specific directory
if matching_files:
ops_path = Path(matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
try:
# Load the module from specific path using importlib
spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
if spec is None:
raise ImportError(f"Could not create module spec for {ops_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {ops_path}")
logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
)
# Continue to fallback
else:
logger.debug(
f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
)
# Try alternative directory (in case installation structure differs)
alt_pattern = str(sgl_kernel_dir / "common_ops.*")
raw_alt_files = glob.glob(alt_pattern)
alt_matching_files = _filter_compiled_extensions(raw_alt_files)
logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
if alt_matching_files:
alt_path = Path(alt_matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
try:
spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
if spec is None:
raise ImportError(f"Could not create module spec for {alt_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {alt_path}")
logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
)
else:
logger.debug(
f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
)
# Final attempt: try standard Python import (for backward compatibility)
logger.debug(
f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
)
try:
import common_ops
logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except ImportError as e:
previous_import_errors.append(e)
logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
attempt_error_msg = "\n".join(
f"- {type(err).__name__}: {err}" for err in previous_import_errors
)
# All attempts failed
error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
Error details from previous import attempts:
{attempt_error_msg}
"""
logger.debug(error_msg)
raise ImportError(error_msg)
# Initialize the ops library based on current GPU
logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
common_ops = _load_architecture_specific_ops()
logger.debug("[sgl_kernel] ✓ Operator library initialization complete")
# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home
# Preload the CUDA library to avoid the issue of libcudart.so.12 not found
if torch.version.cuda is not None:
cuda_home = Path(_find_cuda_home())
_preload_cuda_library()
if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")
cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
......
import ctypes
import glob
import importlib.util
import logging
import os
import shutil
from pathlib import Path
from typing import List
import torch
logger = logging.getLogger(__name__)
def _get_compute_capability():
"""Get the compute capability of the current GPU."""
if not torch.cuda.is_available():
return None
# Get the current device
device = torch.cuda.current_device()
properties = torch.cuda.get_device_properties(device)
# Return as integer (major * 10 + minor)
return properties.major * 10 + properties.minor
def _filter_compiled_extensions(file_list):
"""Filter and prioritize compiled extensions over Python source files."""
compiled_extensions = [".so", ".pyd", ".dll"] # Common compiled extension suffixes
compiled_files = []
other_files = []
for file_path in file_list:
path = Path(file_path)
# Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
if any(
str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions
):
compiled_files.append(file_path)
else:
other_files.append(file_path)
# Return compiled files first, then others
return compiled_files + other_files
def _load_architecture_specific_ops():
"""Load the appropriate common_ops library based on GPU architecture."""
compute_capability = _get_compute_capability()
logger.debug(
f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
)
# Get the directory where sgl_kernel is installed
sgl_kernel_dir = Path(__file__).parent
logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
# Determine which version to load based on GPU architecture
if compute_capability == 90:
ops_subdir = "sm90"
variant_name = "SM90 (Hopper/H100 with fast math optimization)"
elif compute_capability is not None:
ops_subdir = "sm100"
variant_name = f"SM{compute_capability} (precise math for compatibility)"
else:
ops_subdir = "sm100"
variant_name = "CPU/No GPU detected (using precise math)"
# Look for the compiled module with any valid extension
ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
raw_matching_files = glob.glob(ops_pattern)
matching_files = _filter_compiled_extensions(raw_matching_files)
logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
previous_import_errors: List[Exception] = []
# Try to load from the architecture-specific directory
if matching_files:
ops_path = Path(matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
try:
# Load the module from specific path using importlib
spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
if spec is None:
raise ImportError(f"Could not create module spec for {ops_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {ops_path}")
logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
)
# Continue to fallback
else:
logger.debug(
f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
)
# Try alternative directory (in case installation structure differs)
alt_pattern = str(sgl_kernel_dir / "common_ops.*")
raw_alt_files = glob.glob(alt_pattern)
alt_matching_files = _filter_compiled_extensions(raw_alt_files)
logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
if alt_matching_files:
alt_path = Path(alt_matching_files[0]) # Use the first prioritized file
logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
try:
spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
if spec is None:
raise ImportError(f"Could not create module spec for {alt_path}")
common_ops = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module spec has no loader for {alt_path}")
logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
spec.loader.exec_module(common_ops)
logger.debug(f"[sgl_kernel] ✓ Successfully loaded fallback library")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except Exception as e:
previous_import_errors.append(e)
logger.debug(
f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
)
else:
logger.debug(
f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
)
# Final attempt: try standard Python import (for backward compatibility)
logger.debug(
f"[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
)
try:
import common_ops
logger.debug(f"[sgl_kernel] ✓ Successfully imported via standard Python import")
logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
return common_ops
except ImportError as e:
previous_import_errors.append(e)
logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
attempt_error_msg = "\n".join(
f"- {type(err).__name__}: {err}" for err in previous_import_errors
)
# All attempts failed
error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
Error details from previous import attempts:
{attempt_error_msg}
"""
logger.debug(error_msg)
raise ImportError(error_msg)
# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
"""Find the CUDA install path."""
# Guess #1
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home is None:
# Guess #2
nvcc_path = shutil.which("nvcc")
if nvcc_path is not None:
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
else:
# Guess #3
cuda_home = "/usr/local/cuda"
return cuda_home
def _preload_cuda_library():
cuda_home = Path(_find_cuda_home())
if (cuda_home / "lib").is_dir():
cuda_path = cuda_home / "lib"
elif (cuda_home / "lib64").is_dir():
cuda_path = cuda_home / "lib64"
else:
# Search for 'libcudart.so.12' in subdirectories
for path in cuda_home.rglob("libcudart.so.12"):
cuda_path = path.parent
break
else:
raise RuntimeError("Could not find CUDA lib directory.")
cuda_include = (cuda_path / "libcudart.so.12").resolve()
if cuda_include.exists():
ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
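The new sgl_kernel/load_utils.py above centralizes GPU detection, architecture-specific library discovery, and CUDA runtime preloading. A small sketch of how these helpers compose, mirroring what the trimmed __init__.py now does; it assumes an installed sgl_kernel wheel and, for the preload step, a local CUDA toolkit:

import torch

from sgl_kernel.load_utils import (
    _find_cuda_home,
    _get_compute_capability,
    _load_architecture_specific_ops,
    _preload_cuda_library,
)

print("CUDA home:", _find_cuda_home())
print("Compute capability:", _get_compute_capability())  # e.g. 90 on Hopper, None without a GPU

if torch.version.cuda is not None:
    _preload_cuda_library()  # dlopen libcudart.so.12 with RTLD_GLOBAL before loading the ops

common_ops = _load_architecture_specific_ops()  # picks sgl_kernel/sm90 or sgl_kernel/sm100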
@@ -28,11 +28,29 @@ def moe_align_block_size(
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
gating_output: float,
gating_output: torch.Tensor,
renormalize: bool = False,
moe_softcapping: float = 0.0,
correction_bias: Optional[torch.Tensor] = None,
) -> None:
"""
Compute top-k softmax for MoE routing.
Args:
topk_weights: Output tensor for top-k weights [num_tokens, topk]
topk_ids: Output tensor for top-k expert indices [num_tokens, topk]
gating_output: Gating logits [num_tokens, num_experts]
renormalize: Whether to renormalize the top-k weights
moe_softcapping: Tanh softcapping value (0.0 to disable)
correction_bias: Per-expert bias correction [num_experts], must be float32 if provided
"""
torch.ops.sgl_kernel.topk_softmax.default(
topk_weights, topk_ids, gating_output, renormalize
topk_weights,
topk_ids,
gating_output,
renormalize,
moe_softcapping,
correction_bias,
)
......
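With the extended wrapper above, a call site looks roughly like the following. The import path, the output dtypes (float32 weights, int32 ids), and the example shapes are assumptions; per the `Tensor!` schema, the two output tensors are written in place.

import torch

from sgl_kernel.moe import topk_softmax  # assumed module path for the wrapper shown above

num_tokens, num_experts, topk = 8, 64, 2
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float16, device="cuda")

# Preallocated output buffers; the op fills them in place.
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

topk_softmax(
    topk_weights,
    topk_ids,
    gating_output,
    renormalize=True,
    moe_softcapping=30.0,  # 0.0 disables tanh softcapping
    correction_bias=torch.zeros(num_experts, dtype=torch.float32, device="cuda"),
)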