/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_kernel_ops.h"

TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
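  // Ops registered in this block live in the `sgl_kernel` namespace and are
  // callable from Python as torch.ops.sgl_kernel.<name>. TORCH_LIBRARY_EXPAND
  // is presumably a project macro wrapping TORCH_LIBRARY so the namespace can
  // be configured at build time.
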
  /*
   * From csrc/allreduce
   */
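  // Custom all-reduce ops; the graph-buffer helpers (get_graph_buffer_ipc_meta,
  // register_graph_buffers) suggest support for running all_reduce inside
  // captured CUDA graphs.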
  m.def("init_custom_ar", init_custom_ar);
  m.def("dispose", dispose);
  m.def("all_reduce", all_reduce);
  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
  m.def("register_graph_buffers", register_graph_buffers);

  /*
   * From csrc/attention
   */
  m.def("lightning_attention_decode", lightning_attention_decode);

  /*
   * From csrc/elementwise
   */
  m.def("rmsnorm", rmsnorm);
  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
  m.def("gemma_rmsnorm", gemma_rmsnorm);
  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
  m.def("silu_and_mul", silu_and_mul);
  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
  m.def("gelu_and_mul", gelu_and_mul);
  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);

  /*
   * From csrc/gemm
   */
  m.def("awq_dequantize", awq_dequantize);
  m.def("int8_scaled_mm", int8_scaled_mm);
  m.def("fp8_scaled_mm", fp8_scaled_mm);
  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
  m.def("scaled_fp4_quant", scaled_fp4_quant);

  /*
   * From csrc/moe
   */
  m.def("moe_align_block_size", moe_align_block_size);
  m.def("topk_softmax", topk_softmax);

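  // Registered with an explicit schema string plus a device-specific impl:
  // m.def(schema) declares the op, and m.impl(..., torch::kCUDA, ...) binds the
  // CUDA kernel, rather than letting m.def infer the schema from the function
  // pointer and register it as a catch-all for every device.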
  m.def(
      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
      "(Tensor[])");
  m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);

  /*
   * From csrc/speculative
   */
  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
  m.def("verify_tree_greedy", verify_tree_greedy);
  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
  m.def("segment_packbits", segment_packbits);

  /*
   * From FlashInfer
   */
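  // In the schema, `Tensor!` marks D as mutated in place; the cuBLAS handle and
  // CUDA stream cross the op boundary as plain integers (raw pointer values).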
  m.def(
      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
      "cublas_handle, int cuda_stream) -> ()");
  m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
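
  // FlashInfer sampling kernels: min-p/top-k/top-p sampling and the matching
  // probability-renormalization helpers.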
  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
  m.def("top_k_renorm_probs", top_k_renorm_probs);
  m.def("top_p_renorm_probs", top_p_renorm_probs);
  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);

  /*
   * From flash-attention
   */
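  // make_pytorch_shim presumably adapts mha_fwd's native flash-attention
  // signature to the argument and return types expected by torch::Library.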
  m.def("fwd", make_pytorch_shim(mha_fwd));
}

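// REGISTER_EXTENSION presumably expands to the PyInit_* boilerplate that
// exposes this library as the `common_ops` Python extension module.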
REGISTER_EXTENSION(common_ops)