change / sglang · Commits · c0652d90

Unverified commit c0652d90, authored Oct 31, 2025 by Lianmin Zheng, committed via GitHub on Oct 31, 2025.

Clean up sgl kernel (#12413)

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

Parent: 2e48584b

Showing 10 changed files with 513 additions and 378 deletions (+513 −378).
Changed files:

python/sglang/srt/entrypoints/openai/protocol.py   +5   −1
sgl-kernel/CMakeLists.txt                          +26  −40
sgl-kernel/cmake/flashmla.cmake                    +2   −0
sgl-kernel/csrc/common_extension.cc                +52  −49
sgl-kernel/csrc/common_extension_rocm.cc           +9   −9
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu    +129 −16
sgl-kernel/include/sgl_kernel_ops.h                +43  −35
sgl-kernel/python/sgl_kernel/__init__.py           +3   −226
sgl-kernel/python/sgl_kernel/load_utils.py         +224 −0
sgl-kernel/python/sgl_kernel/moe.py                +20  −2
python/sglang/srt/entrypoints/openai/protocol.py

@@ -37,7 +37,11 @@ from pydantic import (
     model_validator,
 )
 from typing_extensions import Literal
-from xgrammar import StructuralTag
+try:
+    from xgrammar import StructuralTag
+except:
+    StructuralTag = Any
 from sglang.utils import convert_json_schema_to_str
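The change above makes xgrammar an optional dependency: when the import fails, StructuralTag degrades to typing.Any so annotations that reference it still resolve at import time. A minimal self-contained sketch of the same pattern (the describe() helper is illustrative, not part of the commit):

```python
from typing import Any

try:
    # Real structural-tag type when xgrammar is installed.
    from xgrammar import StructuralTag
except ImportError:  # the commit uses a bare `except:`; ImportError is the usual narrowing
    # Fallback: annotations using StructuralTag still resolve, just untyped.
    StructuralTag = Any  # type: ignore[assignment, misc]


def describe() -> str:
    """Hypothetical helper: report which binding is active."""
    return "xgrammar available" if StructuralTag is not Any else "fallback to Any"


if __name__ == "__main__":
    print(describe())
```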
sgl-kernel/CMakeLists.txt

@@ -42,7 +42,7 @@ endif()
 find_package(Torch REQUIRED)
 clear_cuda_arches(CMAKE_FLAG)
 
-# Third Party
+# Third Party repos
 # cutlass
 FetchContent_Declare(
   repo-cutlass
@@ -271,6 +271,8 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
   )
 endif()
 
+# All source files
+# NOTE: Please sort the filenames alphabetically
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/allreduce/mscclpp_allreduce.cu"
@@ -279,16 +281,15 @@ set(SOURCES
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/attention/merge_attn_states.cu"
   "csrc/attention/vertical_slash_index.cu"
+  "csrc/common_extension.cc"
   "csrc/elementwise/activation.cu"
   "csrc/elementwise/cast.cu"
-  "csrc/elementwise/copy.cu"
   "csrc/elementwise/concat_mla.cu"
+  "csrc/elementwise/copy.cu"
   "csrc/elementwise/fused_add_rms_norm_kernel.cu"
   "csrc/elementwise/rope.cu"
   "csrc/elementwise/topk.cu"
-  "csrc/common_extension.cc"
-  "csrc/quantization/gguf/gguf_kernel.cu"
+  "csrc/expert_specialization/es_fp8_blockwise.cu"
   "csrc/gemm/awq_kernel.cu"
   "csrc/gemm/bmm_fp8.cu"
@@ -314,10 +315,11 @@ set(SOURCES
   "csrc/gemm/marlin/gptq_marlin_repack.cu"
   "csrc/gemm/marlin/awq_marlin_repack.cu"
   "csrc/gemm/gptq/gptq_kernel.cu"
+  "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
+  "csrc/kvcacheio/transfer.cu"
   "csrc/mamba/causal_conv1d.cu"
+  "csrc/memory/store.cu"
   "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
   "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
@@ -332,16 +334,12 @@ set(SOURCES
   "csrc/moe/fp8_blockwise_moe_kernel.cu"
   "csrc/moe/prepare_moe_input.cu"
-  "csrc/memory/store.cu"
-  "csrc/kvcacheio/transfer.cu"
+  "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/ngram_utils.cu"
   "csrc/speculative/packbit.cu"
   "csrc/speculative/speculative_sampling.cu"
-  "csrc/expert_specialization/es_fp8_blockwise.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
@@ -356,17 +354,7 @@ set(SOURCES
   "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
 )
 
-# =========================== Common SM90 Build ============================= #
-# Build SM90 library with fast math optimization (same namespace, different directory)
-Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
-target_compile_definitions(common_ops_sm90_build PRIVATE USE_FAST_MATH=1)
-target_compile_options(common_ops_sm90_build PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>)
-target_include_directories(common_ops_sm90_build PRIVATE
+set(INCLUDES
   ${repo-cutlass_SOURCE_DIR}/include
   ${repo-cutlass_SOURCE_DIR}/tools/util/include
   ${repo-flashinfer_SOURCE_DIR}/include
@@ -376,6 +364,15 @@ target_include_directories(common_ops_sm90_build PRIVATE
   ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
 )
 
+# =========================== Common SM90 Build ============================= #
+# Build SM90 library with fast math optimization (same namespace, different directory)
+Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+target_compile_options(common_ops_sm90_build PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>)
+target_include_directories(common_ops_sm90_build PRIVATE ${INCLUDES})
+
 # Set output name and separate build directory to avoid conflicts
 set_target_properties(common_ops_sm90_build PROPERTIES
   OUTPUT_NAME "common_ops"
@@ -386,22 +383,10 @@ set_target_properties(common_ops_sm90_build PROPERTIES
 # Build SM100+ library with precise math (same namespace, different directory)
 Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
-target_compile_definitions(common_ops_sm100_build PRIVATE USE_FAST_MATH=0)
 target_compile_options(common_ops_sm100_build PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
-target_include_directories(common_ops_sm100_build PRIVATE
-  ${repo-cutlass_SOURCE_DIR}/include
-  ${repo-cutlass_SOURCE_DIR}/tools/util/include
-  ${repo-flashinfer_SOURCE_DIR}/include
-  ${repo-flashinfer_SOURCE_DIR}/csrc
-  ${repo-mscclpp_SOURCE_DIR}/include
-  ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
-  ${repo-cutlass_SOURCE_DIR}/examples/common
-  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
-)
+target_include_directories(common_ops_sm100_build PRIVATE ${INCLUDES})
 
 # Set output name and separate build directory to avoid conflicts
 set_target_properties(common_ops_sm100_build PROPERTIES
   OUTPUT_NAME "common_ops"
@@ -432,6 +417,7 @@ add_subdirectory(
   ${repo-mscclpp_SOURCE_DIR}
   ${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build)
 target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
+target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
@@ -453,7 +439,7 @@ target_compile_definitions(common_ops_sm100_build PRIVATE
 install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90)
 install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100)
 
-# ============================ Optional Install ============================= #
+# ============================ Optional Install: FA3 ============================= #
 # set flash-attention sources file
 # Now FA3 support sm80/sm86/sm90
 if (SGL_KERNEL_ENABLE_FA3)
@@ -553,10 +539,10 @@ target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERN
 target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
 install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
 
-# ============================ Extra Install ============================= #
+# ============================ Extra Install: FlashMLA ============================= #
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake)
 
-# ============================ DeepGEMM (JIT) ============================= #
+# ============================ Extra Install: DeepGEMM (JIT) ============================= #
 # Create a separate library for DeepGEMM's Python API.
 # This keeps its compilation isolated from the main common_ops.
 set(DEEPGEMM_SOURCES
@@ -601,13 +587,13 @@ install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
 install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
         DESTINATION "deep_gemm/include/cutlass")
 
-# triton_kernels
+# ============================ Extra Install: triton kernels ============================= #
 install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
         DESTINATION "triton_kernels"
         PATTERN ".git*" EXCLUDE
         PATTERN "__pycache__" EXCLUDE)
 
-# flash attention 4
+# ============================ Extra Install: FA4 ============================= #
 # TODO: find a better install condition.
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   # flash_attn/cute
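Both build targets keep OUTPUT_NAME "common_ops" but install into separate wheel subdirectories (sgl_kernel/sm90 and sgl_kernel/sm100), so one wheel ships the fast-math and precise-math variants side by side. A minimal sketch, assuming an installed sgl_kernel wheel, that lists which variants are present on disk:

```python
from pathlib import Path

import sgl_kernel  # assumes an installed sgl_kernel wheel

# Per the install() rules above, each variant lands as common_ops.* in its
# own architecture subdirectory of the package.
pkg_dir = Path(sgl_kernel.__file__).parent
for subdir in ("sm90", "sm100"):
    hits = sorted((pkg_dir / subdir).glob("common_ops.*"))
    print(f"{subdir}: {[p.name for p in hits] or 'not installed'}")
```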
sgl-kernel/cmake/flashmla.cmake

@@ -13,6 +13,8 @@ set(FLASHMLA_CUDA_FLAGS
   "--expt-relaxed-constexpr"
   "--expt-extended-lambda"
   "--use_fast_math"
+  "-Xcudafe=--diag_suppress=177"
+  # variable was declared but never referenced
 )
 
 # The FlashMLA kernels only work on hopper and require CUDA 12.4 or later.
sgl-kernel/csrc/common_extension.cc

@@ -118,37 +118,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "topk_indices_offset) -> ()");
   m.impl("fast_topk_transform_ragged_fused", torch::kCUDA, &fast_topk_transform_ragged_interface);
 
-  /*
-   * From gguf quantiztion
-   */
-  m.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? dtype) -> Tensor");
-  m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
-  m.def("ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
-  m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
-  m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
-  m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
-  m.def(
-      "ggml_moe_a8(Tensor X, Tensor W, "
-      "Tensor sorted_token_ids, Tensor expert_ids, Tensor num_tokens_post_padded, "
-      "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
-  m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
-  m.def(
-      "ggml_moe_a8_vec(Tensor X, Tensor W, "
-      "Tensor topk_ids, int top_k, "
-      "int type, SymInt row, SymInt tokens) -> Tensor");
-  m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
-  m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
 
   /*
    * From csrc/gemm
    */
@@ -256,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
-  m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
+      "moe_softcapping, Tensor? correction_bias) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 
   m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
@@ -289,22 +260,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
   m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
 
-  /*
-   * From csrc/moe/marlin_moe_wna16
-   */
-  m.def(
-      "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
-      "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
-      "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
-      "Tensor sorted_token_ids,"
-      "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
-      "Tensor! topk_weights, int moe_block_size, int top_k, "
-      "bool mul_topk_weights, bool is_ep, int b_q_type_id,"
-      "int size_m, int size_n, int size_k,"
-      "bool is_k_full, bool use_atomic_add,"
-      "bool use_fp32_reduce, bool is_zp_float) -> Tensor");
-  m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
 
   /*
    * From csrc/moe/cutlass_moe/w4a8
    */
@@ -324,6 +279,22 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " int chunk_size, int topk) -> ()");
   m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
 
+  /*
+   * From csrc/moe/marlin_moe_wna16
+   */
+  m.def(
+      "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
+      "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
+      "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
+      "Tensor sorted_token_ids,"
+      "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
+      "Tensor! topk_weights, int moe_block_size, int top_k, "
+      "bool mul_topk_weights, bool is_ep, int b_q_type_id,"
+      "int size_m, int size_n, int size_k,"
+      "bool is_k_full, bool use_atomic_add,"
+      "bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+  m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
 
   /*
    * From csrc/speculative
    */
@@ -521,6 +492,38 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
 
+  /*
+   * From csrc/quantization/gguf
+   */
+  m.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? dtype) -> Tensor");
+  m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
+  m.def("ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
+  m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
+  m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
+  m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
+  m.def(
+      "ggml_moe_a8(Tensor X, Tensor W, "
+      "Tensor sorted_token_ids, Tensor expert_ids, Tensor num_tokens_post_padded, "
+      "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
+  m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
+  m.def(
+      "ggml_moe_a8_vec(Tensor X, Tensor W, "
+      "Tensor topk_ids, int top_k, "
+      "int type, SymInt row, SymInt tokens) -> Tensor");
+  m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
+  m.def("ggml_moe_get_block_size(int type) -> int");
+  m.impl("ggml_moe_get_block_size", torch::kCUDA, &ggml_moe_get_block_size);
 
   /*
    * From csrc/mamba
    */
@@ -556,7 +559,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
 
   /*
-   * From hadamard-transform
+   * From fast-hadamard-transform
    */
   m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
   m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
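After this registration, CUDA builds expose a single topk_softmax overload whose schema carries the two new trailing arguments (float moe_softcapping, Tensor? correction_bias). A quick sketch that prints the registered schema back out, assuming a successful CUDA build (_schema is a private PyTorch attribute, used here only for inspection):

```python
import torch
import sgl_kernel  # noqa: F401  # importing registers the sgl_kernel op library

# Expected to match the m.def above:
# sgl_kernel::topk_softmax(Tensor! topk_weights, Tensor! topk_indices,
#     Tensor gating_output, bool renormalize, float moe_softcapping,
#     Tensor? correction_bias) -> ()
print(torch.ops.sgl_kernel.topk_softmax.default._schema)
```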
sgl-kernel/csrc/common_extension_rocm.cc

@@ -70,7 +70,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
 
-  // quick allreduce
 #ifdef USE_ROCM
   m.def(
       "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
       "cast_bf2half) -> ()");
@@ -86,7 +85,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   // Max input size in bytes
   m.def("qr_max_size", &qr_max_size);
 #endif
-
   /*
    * From csrc/moe
@@ -97,7 +95,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
-  m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
+      "moe_softcapping, Tensor? correction_bias) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 
   /*
@@ -116,12 +116,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "()");
   m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
 
-  /*
-   * From XGrammar
-   */
-  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
-  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 
   /*
    * From csrc/kvcacheio
    */
@@ -171,6 +165,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def(
       "transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
       "Tensor dst_indices, int page_size) -> ()");
   m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
 
+  /*
+   * From csrc/grammar
+   */
+  m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
+  m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
 }
 
 REGISTER_EXTENSION(common_ops)
sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu

@@ -73,8 +73,13 @@ __device__ float convert_to_float(T x) {
 // We have our own implementation of softmax here so we can support transposing the output
 // in the softmax kernel when we extend this module to support expert-choice routing.
 template <typename T, int TPB>
-__launch_bounds__(TPB) __global__
-    void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
+__launch_bounds__(TPB) __global__ void moeSoftmax(
+    const T* input,
+    const bool* finished,
+    float* output,
+    const int num_cols,
+    const float moe_softcapping,
+    const float* correction_bias) {
   using BlockReduce = cub::BlockReduce<float, TPB>;
   __shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -90,9 +95,23 @@ __launch_bounds__(TPB) __global__
     return;
   }
 
+  // First pass: Apply transformation, find max, and write transformed values to output
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    threadData = max(convert_to_float<T>(input[idx]), threadData);
+    float val = convert_to_float<T>(input[idx]);
+
+    // Apply tanh softcapping if enabled
+    if (moe_softcapping != 0.0f) {
+      val = tanhf(val / moe_softcapping) * moe_softcapping;
+    }
+
+    // Apply correction bias if provided
+    if (correction_bias != nullptr) {
+      val = val + correction_bias[ii];
+    }
+
+    output[idx] = val;  // Store transformed value
+    threadData = max(val, threadData);
   }
 
   const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
@@ -102,11 +121,11 @@ __launch_bounds__(TPB) __global__
   }
   __syncthreads();
 
+  // Second pass: Compute sum using transformed values from output
   threadData = 0;
 
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    threadData += exp((convert_to_float<T>(input[idx]) - float_max));
+    threadData += exp((output[idx] - float_max));
   }
 
   const auto Z = BlockReduce(tmpStorage).Sum(threadData);
@@ -116,10 +135,11 @@ __launch_bounds__(TPB) __global__
   }
   __syncthreads();
 
+  // Third pass: Compute final softmax using transformed values from output
   for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
     const int idx = thread_row_offset + ii;
-    const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
-    output[idx] = val;
+    const float softmax_val = exp((output[idx] - float_max)) * normalizing_factor;
+    output[idx] = softmax_val;
   }
 }
@@ -216,7 +236,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
     const int k,
     const int start_expert,
     const int end_expert,
-    const bool renormalize) {
+    const bool renormalize,
+    const float moe_softcapping,
+    const float* correction_bias) {
   // We begin by enforcing compile time assertions and setting up compile time constants.
   static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
   static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
@@ -283,16 +305,48 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
   AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
   const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
 #pragma unroll
+  // Note(Byron): interleaved loads to achieve better memory coalescing
+  // | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ...
   for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
     row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
   }
 
   float row_chunk[VPT];
 #pragma unroll
+  // Note(Byron): upcast logits to float32
   for (int ii = 0; ii < VPT; ++ii) {
     row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
   }
 
+  // Apply tanh softcapping and correction bias
+  if (moe_softcapping != 0.0f || correction_bias != nullptr) {
+#pragma unroll
+    for (int ii = 0; ii < VPT; ++ii) {
+      float val = row_chunk[ii];
+
+      // Apply tanh softcapping if enabled
+      if (moe_softcapping != 0.0f) {
+        val = tanhf(val / moe_softcapping) * moe_softcapping;
+      }
+
+      // Apply correction bias if provided
+      if (correction_bias != nullptr) {
+        /*
+          LDG is interleaved
+          |thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG|
+          |--------- group0 --------| |----------group1 --------|
+                                       ^ local2
+        */
+        const int group_id = ii / ELTS_PER_LDG;
+        const int local_id = ii % ELTS_PER_LDG;
+        const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id;
+        val = val + correction_bias[expert_idx];
+      }
+
+      row_chunk[ii] = val;
+    }
+  }
+
   // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
   // convert to float afterwards for the exp + sum reduction.
   float thread_max = row_chunk[0];
@@ -301,9 +355,15 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
     thread_max = max(thread_max, row_chunk[ii]);
   }
 
+  /*********************************/
+  /********* Softmax Begin *********/
+  /*********************************/
+
   // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
+  // lane id: 0-31 within a warp
 #pragma unroll
   for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
+    // butterfly reduce with (lane id ^ mask)
     thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));
   }
@@ -333,6 +393,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
   for (int ii = 0; ii < VPT; ++ii) {
     row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
   }
 
+  /*******************************/
+  /********* Softmax End *********/
+  /*******************************/
+
   // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
   // with the max index.
@@ -438,6 +501,8 @@ void topkGatingSoftmaxLauncherHelper(
     const int start_expert,
     const int end_expert,
     const bool renormalize,
+    const float moe_softcapping,
+    const float* correction_bias,
     cudaStream_t stream) {
   static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -450,12 +515,33 @@ void topkGatingSoftmaxLauncherHelper(
   dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
   topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
-      input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
+      input,
+      finished,
+      output,
+      num_rows,
+      indices,
+      k,
+      start_expert,
+      end_expert,
+      renormalize,
+      moe_softcapping,
+      correction_bias);
 }
 
 #define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB)              \
   topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>(  \
-      gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
+      gating_output,                                                 \
+      nullptr,                                                       \
+      topk_weights,                                                  \
+      topk_indices,                                                  \
+      num_tokens,                                                    \
+      topk,                                                          \
+      0,                                                             \
+      num_experts,                                                   \
+      renormalize,                                                   \
+      moe_softcapping,                                               \
+      correction_bias,                                               \
+      stream);
 
 template <typename T>
 void topkGatingSoftmaxKernelLauncher(
@@ -467,6 +553,8 @@ void topkGatingSoftmaxKernelLauncher(
     const int num_experts,
     const int topk,
     const bool renormalize,
+    const float moe_softcapping,
+    const float* correction_bias,
     cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
   switch (num_experts) {
@@ -502,7 +590,8 @@ void topkGatingSoftmaxKernelLauncher(
       TORCH_CHECK(
           softmax_workspace != nullptr,
           "softmax_workspace must be provided for num_experts that are not a power of 2.");
       static constexpr int TPB = 256;
-      moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
+      moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(
+          gating_output, nullptr, softmax_workspace, num_experts, moe_softcapping, correction_bias);
       moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
           softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
@@ -510,11 +599,12 @@ void topkGatingSoftmaxKernelLauncher(
 }
 
 void topk_softmax(
-    torch::Tensor& topk_weights,  // [num_tokens, topk]
-    torch::Tensor& topk_indices,  // [num_tokens, topk]
-    torch::Tensor& gating_output,
-    const bool renormalize) {  // [num_tokens, num_experts]
+    torch::Tensor& topk_weights,   // [num_tokens, topk]
+    torch::Tensor& topk_indices,   // [num_tokens, topk]
+    torch::Tensor& gating_output,  // [num_tokens, num_experts]
+    const bool renormalize,
+    const double moe_softcapping,
+    const c10::optional<torch::Tensor>& correction_bias) {
   // Check data type
   TORCH_CHECK(
       gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
@@ -552,6 +642,23 @@ void topk_softmax(
       torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
   const at::ScalarType dtype = gating_output.scalar_type();
 
+  // Validate correction_bias if provided - must always be float32
+  const float* bias_ptr = nullptr;
+  if (correction_bias.has_value()) {
+    const torch::Tensor& bias_tensor = correction_bias.value();
+    TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]");
+    TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts");
+    TORCH_CHECK(
+        bias_tensor.scalar_type() == at::ScalarType::Float,
+        "correction_bias must be float32, got ",
+        bias_tensor.scalar_type());
+    bias_ptr = bias_tensor.data_ptr<float>();
+  }
+
+  // Cast moe_softcapping from double to float for CUDA kernels
+  const float moe_softcapping_f = static_cast<float>(moe_softcapping);
+
   if (dtype == at::ScalarType::Float) {
     topkGatingSoftmaxKernelLauncher<float>(
         gating_output.data_ptr<float>(),
@@ -562,6 +669,8 @@ void topk_softmax(
         num_experts,
         topk,
         renormalize,
+        moe_softcapping_f,
+        bias_ptr,
         stream);
   } else if (dtype == at::ScalarType::Half) {
     topkGatingSoftmaxKernelLauncher<__half>(
@@ -573,6 +682,8 @@ void topk_softmax(
         num_experts,
         topk,
         renormalize,
+        moe_softcapping_f,
+        bias_ptr,
         stream);
   } else if (dtype == at::ScalarType::BFloat16) {
     topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
@@ -584,6 +695,8 @@ void topk_softmax(
        num_experts,
        topk,
        renormalize,
+        moe_softcapping_f,
+        bias_ptr,
        stream);
  } else {
    TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
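To make the fused kernel's math concrete: it optionally softcaps each logit with tanh, optionally adds a per-expert bias, then takes a row softmax and the top-k, renormalizing the selected weights when requested. A minimal eager-mode PyTorch reference of that pipeline (a sketch for intuition and testing, not the kernel itself):

```python
from typing import Optional, Tuple

import torch


def topk_softmax_reference(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    renormalize: bool = False,
    moe_softcapping: float = 0.0,
    correction_bias: Optional[torch.Tensor] = None,  # [num_experts], float32
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = gating_output.float()  # the kernel also upcasts logits to float32
    if moe_softcapping != 0.0:
        # Tanh softcapping, exactly as in the kernel: tanhf(x / cap) * cap
        scores = torch.tanh(scores / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        # Per-expert bias is added after softcapping, before the softmax
        scores = scores + correction_bias.float()
    probs = torch.softmax(scores, dim=-1)
    weights, indices = torch.topk(probs, topk, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, indices
```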
sgl-kernel/include/sgl_kernel_ops.h

@@ -191,32 +191,6 @@ void fast_topk_transform_ragged_interface(
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
 #endif
 
-/*
- * From gguf quantization
- */
-torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
-torch::Tensor ggml_moe_a8(
-    torch::Tensor X,
-    torch::Tensor W,
-    torch::Tensor sorted_token_ids,
-    torch::Tensor expert_ids,
-    torch::Tensor num_tokens_post_padded,
-    int64_t type,
-    int64_t row,
-    int64_t top_k,
-    int64_t tokens);
-torch::Tensor ggml_moe_a8_vec(
-    torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
-int64_t ggml_moe_get_block_size(int64_t type);
 
 /*
  * From csrc/gemm
 */
@@ -333,7 +307,12 @@ void moe_align_block_size(
     bool pad_sorted_token_ids);
 
 void topk_softmax(
-    torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize);
+    torch::Tensor& topk_weights,
+    torch::Tensor& topk_indices,
+    torch::Tensor& gating_output,
+    bool renormalize,
+    double moe_softcapping,
+    const c10::optional<torch::Tensor>& correction_bias);
 
 void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);
@@ -417,6 +396,7 @@ void silu_and_mul_scaled_fp4_experts_quant(
     torch::Tensor const& input_global_scale,
     torch::Tensor const& mask,
     bool use_silu_and_mul);
+
 /*
  * From csrc/moe/cutlass_moe/w4a8
 */
@@ -445,7 +425,9 @@ void cutlass_w4a8_moe_mm(
     torch::Tensor const& s_strides,
     int64_t chunk_size,
     int64_t topk);
+
+/*
+ * From csrc/moe/marlin_moe_wna16
+ */
 torch::Tensor moe_wna16_marlin_gemm(
     torch::Tensor& a,
     std::optional<torch::Tensor> const& c_or_none,
@@ -680,6 +662,11 @@ void transfer_kv_all_layer_direct_lf_pf(
     const at::Tensor& dst_indices,
     int64_t page_size);
 
+/*
+ * From csrc/memory
+ */
+void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
+
 /*
  * From FlashInfer
 */
@@ -798,12 +785,12 @@ void convert_vertical_slash_indexes_mergehead(
     bool causal);
 
 /*
- * From XGrammar
+ * From csrc/grammar
 */
 void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
 
 /*
- * From QServe
+ * From csrc/gemm (QServe)
 */
 void qserve_w4a8_per_chn_gemm(
     const torch::Tensor& _in_feats,
@@ -824,14 +811,35 @@ void qserve_w4a8_per_group_gemm(
     torch::Tensor& _out_feats);
 
 /*
- * From csrc/spatial
+ * From csrc/quantization/gguf
 */
-std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
+torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
+torch::Tensor ggml_moe_a8(
+    torch::Tensor X,
+    torch::Tensor W,
+    torch::Tensor sorted_token_ids,
+    torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_padded,
+    int64_t type,
+    int64_t row,
+    int64_t top_k,
+    int64_t tokens);
+torch::Tensor ggml_moe_a8_vec(
+    torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
+int64_t ggml_moe_get_block_size(int64_t type);
 
 /*
- * From csrc/memory
+ * From csrc/spatial
 */
-void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
+std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
 
 /*
  * From csrc/mamba
@@ -883,7 +891,7 @@ torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
 torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
 
 /*
- * From csrc/fastertransformer
+ * From flashmla
 */
 std::vector<at::Tensor> get_mla_decoding_metadata(
     at::Tensor& seqlens_k,
sgl-kernel/python/sgl_kernel/__init__.py

 import ctypes
 import logging
 import os
 import shutil
 from pathlib import Path
 from typing import List
 
 import torch
 
 logger = logging.getLogger(__name__)
 
-def _get_compute_capability():
-    """Get the compute capability of the current GPU."""
-    if not torch.cuda.is_available():
-        return None
-    # Get the current device
-    device = torch.cuda.current_device()
-    properties = torch.cuda.get_device_properties(device)
-    # Return as integer (major * 10 + minor)
-    return properties.major * 10 + properties.minor
-
-def _filter_compiled_extensions(file_list):
-    """Filter and prioritize compiled extensions over Python source files."""
-    compiled_extensions = [".so", ".pyd", ".dll"]  # Common compiled extension suffixes
-    compiled_files = []
-    other_files = []
-    for file_path in file_list:
-        path = Path(file_path)
-        # Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
-        if any(str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions):
-            compiled_files.append(file_path)
-        else:
-            other_files.append(file_path)
-    # Return compiled files first, then others
-    return compiled_files + other_files
-
-def _load_architecture_specific_ops():
-    """Load the appropriate common_ops library based on GPU architecture."""
-    import importlib.util
-    import sys
-    from pathlib import Path
-
-    compute_capability = _get_compute_capability()
-    logger.debug(f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}")
-    # Get the directory where sgl_kernel is installed
-    sgl_kernel_dir = Path(__file__).parent
-    logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
-    # Determine which version to load based on GPU architecture
-    if compute_capability == 90:
-        ops_subdir = "sm90"
-        variant_name = "SM90 (Hopper/H100 with fast math optimization)"
-    elif compute_capability is not None:
-        ops_subdir = "sm100"
-        variant_name = f"SM{compute_capability} (precise math for compatibility)"
-    else:
-        ops_subdir = "sm100"
-        variant_name = "CPU/No GPU detected (using precise math)"
-    # Look for the compiled module with any valid extension
-    import glob
-    ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
-    raw_matching_files = glob.glob(ops_pattern)
-    matching_files = _filter_compiled_extensions(raw_matching_files)
-    logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
-    logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
-    logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
-    logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
-    previous_import_errors: List[Exception] = []
-    # Try to load from the architecture-specific directory
-    if matching_files:
-        ops_path = Path(matching_files[0])  # Use the first prioritized file
-        logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
-        try:
-            # Load the module from specific path using importlib
-            spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
-            if spec is None:
-                raise ImportError(f"Could not create module spec for {ops_path}")
-            common_ops = importlib.util.module_from_spec(spec)
-            if spec.loader is None:
-                raise ImportError(f"Module spec has no loader for {ops_path}")
-            logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
-            spec.loader.exec_module(common_ops)
-            logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
-            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
-            return common_ops
-        except Exception as e:
-            previous_import_errors.append(e)
-            logger.debug(f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}")
-            # Continue to fallback
-    else:
-        logger.debug(f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}")
-    # Try alternative directory (in case installation structure differs)
-    alt_pattern = str(sgl_kernel_dir / "common_ops.*")
-    raw_alt_files = glob.glob(alt_pattern)
-    alt_matching_files = _filter_compiled_extensions(raw_alt_files)
-    logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
-    logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
-    logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
-    if alt_matching_files:
-        alt_path = Path(alt_matching_files[0])  # Use the first prioritized file
-        logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
-        try:
-            spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
-            if spec is None:
-                raise ImportError(f"Could not create module spec for {alt_path}")
-            common_ops = importlib.util.module_from_spec(spec)
-            if spec.loader is None:
-                raise ImportError(f"Module spec has no loader for {alt_path}")
-            logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
-            spec.loader.exec_module(common_ops)
-            logger.debug("[sgl_kernel] ✓ Successfully loaded fallback library")
-            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
-            return common_ops
-        except Exception as e:
-            previous_import_errors.append(e)
-            logger.debug(f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}")
-    else:
-        logger.debug(f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}")
-    # Final attempt: try standard Python import (for backward compatibility)
-    logger.debug("[sgl_kernel] Final attempt: trying standard Python import 'common_ops'")
-    try:
-        import common_ops
-        logger.debug("[sgl_kernel] ✓ Successfully imported via standard Python import")
-        logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
-        return common_ops
-    except ImportError as e:
-        previous_import_errors.append(e)
-        logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
-    attempt_error_msg = "\n".join(f"- {type(err).__name__}: {err}" for err in previous_import_errors)
-    # All attempts failed
-    error_msg = f"""
-[sgl_kernel] CRITICAL: Could not load any common_ops library!
-Attempted locations:
-1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
-2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
-3. Standard Python import: common_ops - failed
-GPU Info:
-- Compute capability: {compute_capability}
-- Expected variant: {variant_name}
-Please ensure sgl_kernel is properly installed with:
-pip install --upgrade sgl_kernel
-Error details from previous import attempts:
-{attempt_error_msg}
-"""
-    logger.debug(error_msg)
-    raise ImportError(error_msg)
+from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library
 
 # Initialize the ops library based on current GPU
 logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
 common_ops = _load_architecture_specific_ops()
 logger.debug("[sgl_kernel] ✓ Operator library initialization complete")
 
-# copy & modify from torch/utils/cpp_extension.py
-def _find_cuda_home():
-    """Find the CUDA install path."""
-    # Guess #1
-    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
-    if cuda_home is None:
-        # Guess #2
-        nvcc_path = shutil.which("nvcc")
-        if nvcc_path is not None:
-            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
-        else:
-            # Guess #3
-            cuda_home = "/usr/local/cuda"
-    return cuda_home
 
 # Preload the CUDA library to avoid the issue of libcudart.so.12 not found
 if torch.version.cuda is not None:
-    cuda_home = Path(_find_cuda_home())
-    if (cuda_home / "lib").is_dir():
-        cuda_path = cuda_home / "lib"
-    elif (cuda_home / "lib64").is_dir():
-        cuda_path = cuda_home / "lib64"
-    else:
-        # Search for 'libcudart.so.12' in subdirectories
-        for path in cuda_home.rglob("libcudart.so.12"):
-            cuda_path = path.parent
-            break
-        else:
-            raise RuntimeError("Could not find CUDA lib directory.")
+    _preload_cuda_library()
-    cuda_include = (cuda_path / "libcudart.so.12").resolve()
-    if cuda_include.exists():
-        ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
 
 from sgl_kernel.allreduce import *
 from sgl_kernel.attention import (
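Because the loader logs every probe at DEBUG level and exposes the result as a module attribute, you can see which variant was chosen by enabling logging before the first import. A small sketch (the printed path is an example, not a guarantee):

```python
import logging

# Configure logging before importing sgl_kernel so the loader's DEBUG
# messages ("Attempting to load SM90 ...", etc.) are visible.
logging.basicConfig(level=logging.DEBUG)

import sgl_kernel  # triggers _load_architecture_specific_ops() at import time

# Points at whichever variant was actually loaded, e.g.
# .../sgl_kernel/sm90/common_ops.abi3.so on an H100 (illustrative path).
print(sgl_kernel.common_ops.__file__)
```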
sgl-kernel/python/sgl_kernel/load_utils.py (new file, mode 100644)

import ctypes
import glob
import importlib.util
import logging
import os
import shutil
from pathlib import Path
from typing import List

import torch

logger = logging.getLogger(__name__)


def _get_compute_capability():
    """Get the compute capability of the current GPU."""
    if not torch.cuda.is_available():
        return None
    # Get the current device
    device = torch.cuda.current_device()
    properties = torch.cuda.get_device_properties(device)
    # Return as integer (major * 10 + minor)
    return properties.major * 10 + properties.minor


def _filter_compiled_extensions(file_list):
    """Filter and prioritize compiled extensions over Python source files."""
    compiled_extensions = [".so", ".pyd", ".dll"]  # Common compiled extension suffixes
    compiled_files = []
    other_files = []
    for file_path in file_list:
        path = Path(file_path)
        # Check if it's a compiled extension (including complex names like .abi3.so, .cpython-312.so)
        if any(str(path).endswith(ext) or ext in str(path) for ext in compiled_extensions):
            compiled_files.append(file_path)
        else:
            other_files.append(file_path)
    # Return compiled files first, then others
    return compiled_files + other_files


def _load_architecture_specific_ops():
    """Load the appropriate common_ops library based on GPU architecture."""
    compute_capability = _get_compute_capability()
    logger.debug(f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}")
    # Get the directory where sgl_kernel is installed
    sgl_kernel_dir = Path(__file__).parent
    logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")
    # Determine which version to load based on GPU architecture
    if compute_capability == 90:
        ops_subdir = "sm90"
        variant_name = "SM90 (Hopper/H100 with fast math optimization)"
    elif compute_capability is not None:
        ops_subdir = "sm100"
        variant_name = f"SM{compute_capability} (precise math for compatibility)"
    else:
        ops_subdir = "sm100"
        variant_name = "CPU/No GPU detected (using precise math)"
    # Look for the compiled module with any valid extension
    ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
    raw_matching_files = glob.glob(ops_pattern)
    matching_files = _filter_compiled_extensions(raw_matching_files)
    logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
    logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
    logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
    logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")
    previous_import_errors: List[Exception] = []
    # Try to load from the architecture-specific directory
    if matching_files:
        ops_path = Path(matching_files[0])  # Use the first prioritized file
        logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
        try:
            # Load the module from specific path using importlib
            spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
            if spec is None:
                raise ImportError(f"Could not create module spec for {ops_path}")
            common_ops = importlib.util.module_from_spec(spec)
            if spec.loader is None:
                raise ImportError(f"Module spec has no loader for {ops_path}")
            logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
            spec.loader.exec_module(common_ops)
            logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            previous_import_errors.append(e)
            logger.debug(f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}")
            # Continue to fallback
    else:
        logger.debug(f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}")
    # Try alternative directory (in case installation structure differs)
    alt_pattern = str(sgl_kernel_dir / "common_ops.*")
    raw_alt_files = glob.glob(alt_pattern)
    alt_matching_files = _filter_compiled_extensions(raw_alt_files)
    logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
    logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
    logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")
    if alt_matching_files:
        alt_path = Path(alt_matching_files[0])  # Use the first prioritized file
        logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
        try:
            spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
            if spec is None:
                raise ImportError(f"Could not create module spec for {alt_path}")
            common_ops = importlib.util.module_from_spec(spec)
            if spec.loader is None:
                raise ImportError(f"Module spec has no loader for {alt_path}")
            logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
            spec.loader.exec_module(common_ops)
            logger.debug("[sgl_kernel] ✓ Successfully loaded fallback library")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            previous_import_errors.append(e)
            logger.debug(f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}")
    else:
        logger.debug(f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}")
    # Final attempt: try standard Python import (for backward compatibility)
    logger.debug("[sgl_kernel] Final attempt: trying standard Python import 'common_ops'")
    try:
        import common_ops

        logger.debug("[sgl_kernel] ✓ Successfully imported via standard Python import")
        logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
        return common_ops
    except ImportError as e:
        previous_import_errors.append(e)
        logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")
    attempt_error_msg = "\n".join(f"- {type(err).__name__}: {err}" for err in previous_import_errors)
    # All attempts failed
    error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
Error details from previous import attempts:
{attempt_error_msg}
"""
    logger.debug(error_msg)
    raise ImportError(error_msg)


# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
    """Find the CUDA install path."""
    # Guess #1
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            cuda_home = "/usr/local/cuda"
    return cuda_home


def _preload_cuda_library():
    cuda_home = Path(_find_cuda_home())
    if (cuda_home / "lib").is_dir():
        cuda_path = cuda_home / "lib"
    elif (cuda_home / "lib64").is_dir():
        cuda_path = cuda_home / "lib64"
    else:
        # Search for 'libcudart.so.12' in subdirectories
        for path in cuda_home.rglob("libcudart.so.12"):
            cuda_path = path.parent
            break
        else:
            raise RuntimeError("Could not find CUDA lib directory.")
    cuda_include = (cuda_path / "libcudart.so.12").resolve()
    if cuda_include.exists():
        ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
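The prioritization in _filter_compiled_extensions is stable: compiled artifacts keep their glob order and simply move ahead of everything else. An illustrative run (note that importing the submodule also executes the package __init__, which loads the ops):

```python
from sgl_kernel.load_utils import _filter_compiled_extensions

candidates = [
    "sm90/common_ops.py",                                # stray source file
    "sm90/common_ops.abi3.so",                           # stable-ABI build
    "sm90/common_ops.cpython-312-x86_64-linux-gnu.so",   # version-specific build
]
print(_filter_compiled_extensions(candidates))
# ['sm90/common_ops.abi3.so',
#  'sm90/common_ops.cpython-312-x86_64-linux-gnu.so',
#  'sm90/common_ops.py']
```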
sgl-kernel/python/sgl_kernel/moe.py

@@ -28,11 +28,29 @@ def moe_align_block_size(
 def topk_softmax(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    gating_output: float,
+    gating_output: torch.Tensor,
     renormalize: bool = False,
+    moe_softcapping: float = 0.0,
+    correction_bias: Optional[torch.Tensor] = None,
 ) -> None:
+    """
+    Compute top-k softmax for MoE routing.
+
+    Args:
+        topk_weights: Output tensor for top-k weights [num_tokens, topk]
+        topk_ids: Output tensor for top-k expert indices [num_tokens, topk]
+        gating_output: Gating logits [num_tokens, num_experts]
+        renormalize: Whether to renormalize the top-k weights
+        moe_softcapping: Tanh softcapping value (0.0 to disable)
+        correction_bias: Per-expert bias correction [num_experts], must be float32 if provided
+    """
     torch.ops.sgl_kernel.topk_softmax.default(
-        topk_weights, topk_ids, gating_output, renormalize
+        topk_weights,
+        topk_ids,
+        gating_output,
+        renormalize,
+        moe_softcapping,
+        correction_bias,
     )
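A minimal usage sketch of the updated wrapper, assuming a CUDA build of sgl_kernel; the int32 dtype for topk_ids is an assumption here, while the float32 requirement on correction_bias comes from the TORCH_CHECK in the kernel:

```python
import torch

from sgl_kernel.moe import topk_softmax

num_tokens, num_experts, topk = 8, 64, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda")
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)  # dtype assumed
correction_bias = torch.zeros(num_experts, device="cuda", dtype=torch.float32)  # must be float32

# Outputs are written in place into topk_weights / topk_ids.
topk_softmax(
    topk_weights,
    topk_ids,
    gating_output,
    renormalize=True,
    moe_softcapping=30.0,  # tanh softcap; 0.0 disables it
    correction_bias=correction_bias,
)
```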