Unverified Commit 0d7fe866 authored by yinfan98, committed by GitHub

[Misc] Clean m.def and add Development Tips (#4890)

parent 54b9a2de
@@ -51,6 +51,47 @@ Steps to add a new kernel:
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
### Development Tips
1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), define only pure CUDA files and C++ interfaces. If you need `torch::Tensor`, include `<torch/all.h>` instead of `<torch/extension.h>`; using `<torch/extension.h>` causes compilation errors when building against the Python stable ABI (SABI).
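For example, a minimal interface sketch (the kernel name `my_kernel` is hypothetical):
```cpp
// Include <torch/all.h>, not <torch/extension.h>, so the extension keeps
// building against the Python stable ABI (SABI).
#include <torch/all.h>

// Pure C++ declaration; the CUDA implementation lives in a separate .cu file.
void my_kernel(torch::Tensor& output, const torch::Tensor& input);
```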
2. When creating torch extensions, register the function directly with `m.def`:
```cpp
m.def("register_graph_buffers", register_graph_buffers);
```
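Passing the function itself, rather than a schema string, lets PyTorch infer the operator schema from the C++ signature, so no separate `m.impl` call is needed. As a sketch, such definitions sit inside the extension's library block (`TORCH_LIBRARY_EXPAND` is the repository's registration macro; the op names are taken from the diff below):
```cpp
// Sketch: direct registration inside the library block. The schema is
// inferred from each function's C++ signature, so no schema string or
// separate m.impl call is required.
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  m.def("register_graph_buffers", register_graph_buffers);
  m.def("dispose", dispose);
}
```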
3. When exposing Python interfaces, avoid passing keyword arguments to the C++ interface kernels; call them positionally, as the examples below show.
**Avoid this:**
```python
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions.long(),
interleave=(not is_neox),
cuda_stream=get_cuda_stream(),
)
```
**Use this instead:**
```python
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
cos_sin_cache,
positions.long(),
(not is_neox),
get_cuda_stream(),
)
```
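Keyword arguments fail here because a schema inferred from the C++ signature does not preserve the C++ parameter names, so ops registered with the plain `m.def("name", func)` form can only be called positionally.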
### Build & Install
Development build:
@@ -22,121 +22,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
-  m.def(
-      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
-      "barrier_in, int[] barrier_out) -> int");
-  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-  m.def("dispose", &dispose);
-  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
-  m.impl("all_reduce", torch::kCUDA, &all_reduce);
-  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("init_custom_ar", init_custom_ar);
+  m.def("dispose", dispose);
+  m.def("all_reduce", all_reduce);
+  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", register_graph_buffers);
   /*
    * From csrc/attention
    */
-  m.def(
-      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
-      "new_kv) -> ()");
-  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def("lightning_attention_decode", lightning_attention_decode);
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
-  m.def(
-      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
-  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
+  m.def("rmsnorm", rmsnorm);
+  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
+  m.def("gemma_rmsnorm", gemma_rmsnorm);
+  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
+  m.def("silu_and_mul", silu_and_mul);
+  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
+  m.def("gelu_and_mul", gelu_and_mul);
+  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
   /*
    * From csrc/gemm
    */
-  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
-  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
-  m.def(
-      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
-  m.def(
-      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
-  m.def(
-      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
-      "Tensor");
-  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
-  m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
-  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
-  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
-  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
-  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
-  m.def(
-      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
-      " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
-  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
-  m.def(
-      "scaled_fp4_quant(Tensor! output, Tensor! input,"
-      " Tensor! output_scale, Tensor! input_scale) -> ()");
-  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+  m.def("awq_dequantize", awq_dequantize);
+  m.def("int8_scaled_mm", int8_scaled_mm);
+  m.def("fp8_scaled_mm", fp8_scaled_mm);
+  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
+  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
+  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
+  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
+  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
+  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
+  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
+  m.def("scaled_fp4_quant", scaled_fp4_quant);
   /*
    * From csrc/moe
    */
-  m.def(
-      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
-  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
-  m.def(
-      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
-      "token_expert_indices, Tensor gating_output) -> ()");
-  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+  m.def("moe_align_block_size", moe_align_block_size);
+  m.def("topk_softmax", topk_softmax);
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
@@ -146,62 +74,20 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/speculative
    */
-  m.def(
-      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
-      "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
-  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
-  m.def(
-      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
-  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
-  m.def(
-      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
-      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
-  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
-  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
+  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
+  m.def("verify_tree_greedy", verify_tree_greedy);
+  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
+  m.def("segment_packbits", segment_packbits);
   /*
    * From FlashInfer
    */
-  m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()");
-  m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
-  m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
-  m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
-  m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+  m.def("bmm_fp8", bmm_fp8);
+  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
+  m.def("top_k_renorm_probs", top_k_renorm_probs);
+  m.def("top_p_renorm_probs", top_p_renorm_probs);
+  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
+  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
 }
 REGISTER_EXTENSION(common_ops)
@@ -142,12 +142,12 @@ def apply_rope_with_cos_sin_cache_inplace(
         raise ValueError("cos_sin_cache should be float32")
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        q=query.view(query.shape[0], -1, head_size),
-        k=key.view(key.shape[0], -1, head_size),
-        q_rope=query.view(query.shape[0], -1, head_size),
-        k_rope=key.view(key.shape[0], -1, head_size),
-        cos_sin_cache=cos_sin_cache,
-        pos_ids=positions.long(),
-        interleave=(not is_neox),
-        cuda_stream=get_cuda_stream(),
+        query.view(query.shape[0], -1, head_size),
+        key.view(key.shape[0], -1, head_size),
+        query.view(query.shape[0], -1, head_size),
+        key.view(key.shape[0], -1, head_size),
+        cos_sin_cache,
+        positions.long(),
+        (not is_neox),
+        get_cuda_stream(),
     )