Unverified commit c7457191, authored by yinfan98, committed by GitHub

[Fix] revert clean m.def for cudagraph (#4944)

parent 51ac297a
@@ -55,9 +55,17 @@ Steps to add a new kernel:
 1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `<torch/all.h>` instead of `<torch/extension.h>`. Using `<torch/extension.h>` will cause compilation errors when using SABI.
-2. When creating torch extensions, simply add the function definition with `m.def`:
+2. When creating torch extensions, add the function definition with `m.def` and the device binding with `m.impl`:
+   - Using `torch.compile` requires `m.def` with a schema, which helps it auto-capture the custom kernel. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
+   - How to write a schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)
 ```cpp
-m.def("register_graph_buffers", register_graph_buffers);
+// We need def with schema here for torch.compile
+m.def(
+    "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
+    "cublas_handle, int cuda_stream) -> ()");
+m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 ```
 3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
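The schema requirement in step 2 above is what lets `torch.compile` see the custom op. As a hedged, illustrative sketch (not part of this commit): assuming the extension registers ops under the `sgl_kernel` namespace (as `TORCH_LIBRARY_EXPAND(sgl_kernel, m)` in the C++ file below suggests) and reusing the `rmsnorm` schema registered in this PR, a Python-side call might look like this:

```python
# Hypothetical usage sketch; op namespace and wrapper are illustrative only.
import torch

def fused_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    out = torch.empty_like(x)
    stream = torch.cuda.current_stream().cuda_stream  # raw CUDA stream handle as int
    # Matches the schema: rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()
    torch.ops.sgl_kernel.rmsnorm(out, x, weight, eps, stream)
    return out

# Because the op is registered with a schema (including the Tensor! mutation marker),
# torch.compile has enough information to capture it instead of falling back.
compiled_rmsnorm = torch.compile(fused_rmsnorm)
```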
@@ -96,6 +104,8 @@ Steps to add a new kernel:
 When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
+> The reason we need `double` and `int64_t` in the torch binding is that `TORCH_LIBRARY` handles the `Python-to-C++` conversion process: Python's `float` corresponds to `double` in C++, while Python's `int` corresponds to `int64_t` in C++.
 To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
 When you need to support new data type conversions, you can easily add conversion functions like this:
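The concrete conversion code lives in `sgl_kernel_torch_shim.h` and is elided from this diff excerpt. As a rough, hypothetical sketch of the shape such a conversion takes (the actual trait and helper names in the header may differ):

```cpp
// Hypothetical sketch only: a type-mapping trait the shim could apply to every argument.
#include <cstdint>

// Primary template: by default a type maps to itself.
template <typename T>
struct pytorch_library_compatible_type {
  using type = T;
  static T convert_from_type(T arg) { return arg; }
};

// Python float arrives as double in the torch binding, but the kernel wants float.
template <>
struct pytorch_library_compatible_type<float> {
  using type = double;
  static float convert_from_type(double arg) { return static_cast<float>(arg); }
};

// Python int arrives as int64_t in the torch binding, but the kernel wants int.
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) { return static_cast<int>(arg); }
};
```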
@@ -119,7 +129,7 @@ To use this with your library functions, simply wrap them with make_pytorch_shim:
 /*
  * From flash-attention
  */
-m.def("fwd", make_pytorch_shim(mha_fwd));
+m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 ```
 ### Build & Install
@@ -22,49 +22,121 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
-  m.def("init_custom_ar", init_custom_ar);
-  m.def("dispose", dispose);
-  m.def("all_reduce", all_reduce);
-  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers", register_graph_buffers);
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("dispose", &dispose);
+  m.def(
+      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
+      "barrier_in, int[] barrier_out) -> int");
+  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
+  m.impl("all_reduce", torch::kCUDA, &all_reduce);
   /*
    * From csrc/attention
    */
-  m.def("lightning_attention_decode", lightning_attention_decode);
+  m.def(
+      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
+      "new_kv) -> ()");
+  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm", rmsnorm);
-  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
-  m.def("gemma_rmsnorm", gemma_rmsnorm);
-  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
-  m.def("silu_and_mul", silu_and_mul);
-  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
-  m.def("gelu_and_mul", gelu_and_mul);
-  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
+  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
+  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+  m.def(
+      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
+      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
   /*
    * From csrc/gemm
    */
-  m.def("awq_dequantize", awq_dequantize);
-  m.def("int8_scaled_mm", int8_scaled_mm);
-  m.def("fp8_scaled_mm", fp8_scaled_mm);
-  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
-  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
-  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
-  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
-  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
-  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
-  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
-  m.def("scaled_fp4_quant", scaled_fp4_quant);
+  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
+  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+  m.def(
+      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
+  m.def(
+      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
+  m.def(
+      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
+      "Tensor");
+  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
+  m.def(
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
+  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
+  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
+  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
+  m.def(
+      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
+      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
+  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
+  m.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+  m.def(
+      "scaled_fp4_quant(Tensor! output, Tensor! input,"
+      " Tensor! output_scale, Tensor! input_scale) -> ()");
+  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
   /*
    * From csrc/moe
    */
-  m.def("moe_align_block_size", moe_align_block_size);
-  m.def("topk_softmax", topk_softmax);
+  m.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
+  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
@@ -74,10 +146,28 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/speculative
    */
-  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
-  m.def("verify_tree_greedy", verify_tree_greedy);
-  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
-  m.def("segment_packbits", segment_packbits);
+  m.def(
+      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
+      "float threshold_single, float threshold_acc, "
+      "bool deterministic, int cuda_stream) -> ()");
+  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
+  m.def(
+      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor target_predict, int cuda_stream) -> ()");
+  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
+  m.def(
+      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
+      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
+  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
+  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
   /*
    * From FlashInfer
@@ -86,16 +176,71 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
       "cublas_handle, int cuda_stream) -> ()");
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
-  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
-  m.def("top_k_renorm_probs", top_k_renorm_probs);
-  m.def("top_p_renorm_probs", top_p_renorm_probs);
-  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
-  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
+  m.def(
+      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
+  m.def(
+      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
+      "cuda_stream) -> ()");
+  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
+  m.def(
+      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
+      "cuda_stream) -> ()");
+  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
+  m.def(
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
+      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
+      "cuda_stream) -> ()");
+  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
+  m.def(
+      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
   /*
    * From flash-attention
    */
-  m.def("fwd", make_pytorch_shim(mha_fwd));
+  m.def(
+      "fwd(Tensor! q,"
+      " Tensor k,"
+      " Tensor v,"
+      " Tensor? k_new,"
+      " Tensor? v_new,"
+      " Tensor? q_v,"
+      " Tensor!? out,"
+      " Tensor? cu_seqlens_q,"
+      " Tensor? cu_seqlens_k,"
+      " Tensor? cu_seqlens_k_new,"
+      " Tensor? seqused_q,"
+      " Tensor? seqused_k,"
+      " int? max_seqlen_q,"
+      " int? max_seqlen_k,"
+      " Tensor? page_table,"
+      " Tensor? kv_batch_idx,"
+      " Tensor? leftpad_k,"
+      " Tensor? rotary_cos,"
+      " Tensor? rotary_sin,"
+      " Tensor? seqlens_rotary,"
+      " Tensor? q_descale,"
+      " Tensor? k_descale,"
+      " Tensor? v_descale,"
+      " float softmax_scale,"
+      " bool is_causal,"
+      " int window_size_left,"
+      " int window_size_right,"
+      " float softcap,"
+      " bool is_rotary_interleaved,"
+      " Tensor? scheduler_metadata,"
+      " int num_splits,"
+      " bool? pack_gqa,"
+      " int sm_margin) -> Tensor[]");
+  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
 REGISTER_EXTENSION(common_ops)