sgl_kernel_ops.h 12.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16
#pragma once
17

18
19
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
20
#include <Python.h>
21
22
#include <torch/library.h>
#include <torch/torch.h>
23

24
25
#include <vector>

26
27
#include "sgl_kernel_torch_shim.h"

28
29
30
31
32
33
34
35
36
37
38
39
40
41
// Token-pasting helper. CONCAT goes through _CONCAT so that its arguments are
// macro-expanded before `##` is applied; call sites should use CONCAT.
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

// Stringification helper, using the same two-level expansion as CONCAT so the
// argument is macro-expanded before `#` is applied.
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// Forwards to TORCH_LIBRARY through one extra macro layer, so NAME is
// macro-expanded first (presumably because TORCH_LIBRARY pastes or
// stringifies its first argument — confirm in torch/library.h).
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// Defines the CPython module entry point PyInit_<NAME>. The PyModuleDef is
// initialised with the stringified NAME as the module name, a null docstring,
// size 0 and a null method table, so the module created here carries no
// Python-level functions of its own.
// NOTE(review): `_CONCAT` / `_STRINGIFY` start with underscore + uppercase
// letter, which the C++ standard reserves for the implementation; renaming is
// left for a coordinated change because other files may use these names.
#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }

Ke Bao's avatar
Ke Bao committed
42
// Integer type used for the `_fa` argument and the std::vector<fptr_t>
// arguments of the custom all-reduce declarations in this header.
// NOTE(review): presumably holds a pointer value passed as a 64-bit integer —
// confirm against the implementation in csrc/allreduce.
using fptr_t = int64_t;
43
44
45
46

/*
 * From csrc/allreduce
 */
47
#ifdef USE_ROCM
48
// ROCM custom allreduce
49
50
51
52
53
54
55
// ROCm variant of the custom all-reduce initialiser (declared only when
// USE_ROCM is defined). Returns an fptr_t.
// NOTE(review): the return value is presumably the `_fa` handle expected by
// the other all-reduce functions — confirm against csrc/allreduce.
//   meta, rank_data  - tensors taken by non-const reference.
//   handles, offsets - string handles and integer offsets; presumably one
//                      entry per participating rank — TODO confirm.
//   rank             - presumably this process's rank index — TODO confirm.
//   full_nvlink      - boolean flag; meaning is defined by the implementation.
fptr_t init_custom_ar(
    torch::Tensor& meta,
    torch::Tensor& rank_data,
    const std::vector<std::string>& handles,
    const std::vector<int64_t>& offsets,
    int64_t rank,
    bool full_nvlink);
56
57
58
59
// ROCm custom all-reduce entry points (declared only when USE_ROCM is defined).
// Each function that takes `_fa` receives it as an fptr_t.
// NOTE(review): the "reg"/"unreg" suffixes presumably distinguish inputs that
// were registered beforehand from inputs staged through `reg_buffer` —
// confirm against csrc/allreduce.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
// Presumably releases the state behind `_fa` — TODO confirm.
void dispose(fptr_t _fa);
// Takes no arguments and returns an int64_t; presumably a byte size for the
// `meta` buffer — TODO confirm.
int64_t meta_size();
60
61
// ROCm custom all-reduce: takes tensor `t` (by non-const reference) together
// with string `handles` and integer `offsets` for the context `_fa`.
// NOTE(review): presumably registers `t` for use by all_reduce_reg — confirm
// against csrc/allreduce.
void register_buffer(
    fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
62
// ROCm variant: returns a (tensor, vector of int64_t) pair for the context `_fa`.
// NOTE(review): presumably the IPC handle data and offsets of graph-captured
// buffers — confirm against csrc/allreduce.
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
63
64
// ROCm variant: takes string `handles` and a nested vector of integer
// `offsets` for the context `_fa`.
// NOTE(review): presumably consumes the values gathered from
// get_graph_buffer_ipc_meta on every rank — confirm against csrc/allreduce.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
65
66
67
// ROCm-only helpers (declared inside the USE_ROCM branch).
// allocate_meta_buffer returns a tensor for the requested `size`; units are
// presumably bytes — TODO confirm.
torch::Tensor allocate_meta_buffer(int64_t size);
// Returns a tensor derived from `inp`; presumably the serialized IPC handle
// of its storage — TODO confirm.
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
68
// TRTLLM custom allreduce
69
70
71
72
73
74
75
76
// Non-ROCm (TRTLLM) variant of the custom all-reduce initialiser, declared in
// the #else branch of USE_ROCM. Returns an fptr_t.
// NOTE(review): the return value is presumably the `_fa` handle expected by
// dispose / all_reduce — confirm against csrc/allreduce.
//   rank_id, world_size - presumably this rank's index and the total number
//                         of ranks — TODO confirm.
//   rank_data           - tensor taken by non-const reference.
//   buffers, tmp_result_buffers, barrier_in, barrier_out
//                       - vectors of fptr_t; presumably one address per rank
//                         — TODO confirm.
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out);
Ke Bao's avatar
Ke Bao committed
77
78
// Non-ROCm (TRTLLM) custom all-reduce entry points.
// dispose presumably releases the state behind `_fa` — TODO confirm.
void dispose(fptr_t _fa);
// Takes `inp` and `out` by non-const reference; presumably reduces `inp`
// across ranks into `out` — TODO confirm.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
79
// Non-ROCm variant: returns two vectors of int64_t for the context `_fa`
// (the ROCm variant returns a tensor as the first element instead).
// NOTE(review): presumably (handle data, offsets) of graph-captured buffers —
// confirm against csrc/allreduce.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
80
81
// Non-ROCm variant: both `handles` and `offsets` are nested vectors of
// int64_t (the ROCm variant takes string handles instead).
// NOTE(review): presumably one inner vector per rank — TODO confirm.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
82
#endif
Ke Bao's avatar
Ke Bao committed
83

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/*
 * From csrc/attention
 */
// Lightning-attention decode step. q, k, v, past_kv and slope are read-only
// (const reference); `output` and `new_kv` are tensor handles passed by value.
// NOTE(review): `output` and `new_kv` are presumably written in place by the
// kernel — confirm against csrc/attention. Tensor shapes are not visible here.
void lightning_attention_decode(
    const torch::Tensor& q,
    const torch::Tensor& k,
    const torch::Tensor& v,
    const torch::Tensor& past_kv,
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv);

/*
 * From csrc/elementwise
 */
// Normalisation and gated-activation entry points. All return void.
// Except for sgl_fused_add_rmsnorm, each takes its tensors by non-const
// reference and ends with an int64_t `cuda_stream`.
// NOTE(review): `cuda_stream` is presumably a cudaStream_t passed as an
// integer, and results are presumably written to `output` / `out` (or in
// place for the fused-add variants) — confirm against csrc/elementwise.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
// Unlike its siblings, this one takes tensors by value and has no cuda_stream.
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
107
108
109
110
111
112
113
114
115
// Rotary-embedding entry point taking q/k, their `_rope` counterparts, a
// cos/sin cache and position ids. All tensors are passed by value.
// NOTE(review): `q_rope` / `k_rope` are presumably the outputs, `interleave`
// presumably selects the rotary pairing layout, and `cuda_stream` is
// presumably a cudaStream_t passed as an integer — confirm against the
// implementation.
void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream);
116

117
118
119
/*
 * From csrc/gemm
 */
120
// Takes the quantised weight, scales and quantised zero-points by value and
// returns a new tensor.
// NOTE(review): presumably the dequantised weight in AWQ layout — confirm
// against csrc/gemm.
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
Trevor Morris's avatar
Trevor Morris committed
121
122
123
124
125
126
127
// FP4 scaled matrix multiply. `D` is the only non-const parameter; A, B,
// their `_sf` tensors and `alpha` are const references.
// NOTE(review): `D` is presumably the output, `A_sf` / `B_sf` presumably the
// block scale factors and `alpha` a global scale — confirm against csrc/gemm.
void cutlass_scaled_fp4_mm(
    torch::Tensor& D,
    torch::Tensor const& A,
    torch::Tensor const& B,
    torch::Tensor const& A_sf,
    torch::Tensor const& B_sf,
    torch::Tensor const& alpha);
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Scaled matrix multiply for int8 inputs (per the name). All inputs are const
// references and the result is returned as a new tensor; `bias` is optional.
// NOTE(review): the result presumably has dtype `out_dtype` — TODO confirm.
// NOTE(review): `bias` uses c10::optional whereas the sampling declarations
// in this header use std::optional; confirm whether that is intentional.
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
// Scaled matrix multiply for fp8 inputs (per the name). Same parameter list
// as int8_scaled_mm: const-reference inputs, an optional `bias`, and a new
// tensor returned by value.
// NOTE(review): the result presumably has dtype `out_dtype` — TODO confirm.
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
// Block-wise scaled fp8 matrix multiply (per the name). Takes the same
// parameters as fp8_scaled_mm except that there is no `bias`.
// NOTE(review): the scale-tensor layout and block size are not visible from
// this declaration — see csrc/gemm.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype);
Trevor Morris's avatar
Trevor Morris committed
148
149
// FP4 quantisation entry point. `output` and `output_scale` are non-const
// references; `input` and `input_scale` are const references.
// NOTE(review): the non-const tensors are presumably written by the kernel —
// confirm against csrc/gemm.
void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
150
151
152
153
154
155
156
157
// Per-token-group fp8 quantisation entry point. Tensors are passed by value;
// `group_size`, `eps` and the fp8 range bounds are scalars.
// NOTE(review): `output_q` / `output_s` are presumably the quantised values
// and per-group scales written in place, with results clamped to
// [fp8_min, fp8_max] — confirm against csrc/gemm.
void sgl_per_token_group_quant_fp8(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max);
158
159
160
161
162
163
164
165
// Per-token-group int8 quantisation entry point; the parameter list mirrors
// sgl_per_token_group_quant_fp8 with int8 range bounds.
// NOTE(review): `output_q` / `output_s` are presumably the quantised values
// and per-group scales written in place — confirm against csrc/gemm.
void sgl_per_token_group_quant_int8(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double int8_min,
    double int8_max);
166
// Per-tensor fp8 quantisation entry point; tensors are passed by value.
// NOTE(review): `is_static` presumably means `output_s` already holds the
// scale rather than being computed here — confirm against csrc/gemm.
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
167
// Per-token fp8 quantisation entry point; tensors are passed by value.
// NOTE(review): `output_q` / `output_s` are presumably written in place —
// confirm against csrc/gemm.
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
168
169
170
171
172
173
174
// Grouped GEMM over parallel lists of input, weight and output tensors.
// NOTE(review): the three vectors are presumably the same length, with
// outputs[i] written from inputs[i] and weights[i]; `cublas_handle` and
// `cuda_stream` are presumably a cublasHandle_t and cudaStream_t passed as
// integers — confirm against csrc/gemm.
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs,
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream);
175
176
177
178
179
180
181
182
183
// Batched fp8 matrix multiply entry point (per the name). All tensors are
// passed by value.
// NOTE(review): `D` is presumably the output, `workspace_buffer` presumably
// scratch memory, and `cublas_handle` / `cuda_stream` presumably native
// handles passed as integers — confirm against csrc/gemm.
void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle,
    int64_t cuda_stream);
184

185
186
187
/*
 * From csrc/moe
 */
188
189
190
191
192
193
194
195
196
// MoE helper taking the top-k expert ids plus `num_experts` and `block_size`.
// All tensors are passed by value.
// NOTE(review): `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad`
// are presumably outputs, and `token_cnts_buffer` / `cumsum_buffer`
// presumably caller-provided scratch — confirm against csrc/moe.
void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer);
197

198
199
200
201
202
203
// MoE top-k softmax entry point. All four tensors are non-const references.
// NOTE(review): `gating_output` is presumably the input logits and the other
// three the outputs — confirm against csrc/moe.
void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);

204
205
206
// Fused MoE gating entry point; returns a vector of tensors.
// NOTE(review): the number and meaning of the returned tensors (presumably
// selected weights and expert indices) are not visible here — confirm
// against csrc/moe.
std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);

207
208
209
/*
 * From csrc/speculative
 */
210
void tree_speculative_sampling_target_only(
211
212
    at::Tensor predicts,          // mutable
    at::Tensor accept_index,      // mutable
213
214
215
216
217
218
219
220
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor uniform_samples,
    at::Tensor target_probs,
    at::Tensor draft_probs,
221
222
    double threshold_single = 1,
    double threshold_acc = 1,
223
224
225
    bool deterministic = true,
    int64_t cuda_stream = 0);

226
227
228
229
230
// Greedy verification over a speculative-decoding tree (per the name). The
// first three tensors are marked mutable by the original author.
// `cuda_stream` defaults to 0.
// NOTE(review): `target_predict` presumably holds the target model's
// predicted tokens, and `cuda_stream` is presumably a cudaStream_t passed as
// an integer — confirm against csrc/speculative.
void verify_tree_greedy(
    at::Tensor predicts,          // mutable
    at::Tensor accept_index,      // mutable
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor target_predict,
    int64_t cuda_stream = 0);
236

237
void build_tree_kernel_efficient(
238
239
240
241
242
243
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
244
245
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
246
247
248
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num);
249

250
251
252
// Segment-wise bit packing entry point (per the name). Takes `x`, `y` and two
// indptr tensors by value.
// NOTE(review): `x` is presumably the input and `y` the packed output, with
// the indptr tensors delimiting segments; `cuda_stream` is presumably a
// cudaStream_t passed as an integer — confirm against the implementation.
void segment_packbits(
    at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);

253
254
255
/*
 * From FlashInfer
 */
256
257
258
259
260
261
262
263
264
265
266
267
268
// Min-p sampling entry point (listed under "From FlashInfer"). The threshold
// is given both as an optional tensor and as a scalar `min_p_val`.
// NOTE(review): the tensor presumably holds per-row values that override the
// scalar when present, and `samples` is presumably the output — confirm
// against the FlashInfer sampling implementation.
void min_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    std::optional<at::Tensor> maybe_min_p_arr,
    double min_p_val,
    bool deterministic,
    int64_t cuda_stream);
void top_k_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_k_arr,
    int64_t top_k_val,
269
    int64_t cuda_stream);
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
// Top-p renormalisation entry point (listed under "From FlashInfer"). The p
// value is given both as an optional tensor and as a scalar `top_p_val`.
// NOTE(review): `renorm_probs` is presumably the output, and the tensor
// presumably overrides the scalar per row when present — confirm against the
// FlashInfer implementation.
void top_p_renorm_probs(
    at::Tensor probs,
    at::Tensor renorm_probs,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    int64_t cuda_stream);
// Combined top-k / top-p sampling entry point (listed under "From
// FlashInfer"). Each threshold is given both as an optional tensor and as a
// scalar.
// NOTE(review): `top_k_val` is declared `double` here, whereas
// top_k_renorm_probs in this header declares its `top_k_val` as `int64_t`.
// Confirm which type the definition and the registered op schema use before
// changing either.
// NOTE(review): `samples` and `success` are presumably outputs — confirm
// against the FlashInfer implementation.
void top_k_top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    at::Tensor success,
    std::optional<at::Tensor> maybe_top_k_arr,
    double top_k_val,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
// Top-p sampling entry point (listed under "From FlashInfer"). The p value is
// given both as an optional tensor and as a scalar `top_p_val`.
// NOTE(review): `samples` and `success` are presumably outputs, and
// `cuda_stream` is presumably a cudaStream_t passed as an integer — confirm
// against the FlashInfer implementation.
void top_p_sampling_from_probs(
    at::Tensor probs,
    at::Tensor uniform_samples,
    at::Tensor samples,
    at::Tensor success,
    std::optional<at::Tensor> maybe_top_p_arr,
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340

/*
 * From flash-attention
 */
// Forward multi-head attention entry point (listed under "From
// flash-attention"). Returns a vector of tensors; the count and meaning of
// the returned tensors are not visible from this declaration.
// Expected tensor shapes are given in the per-parameter comments below.
// Parameters with a trailing underscore are std::optional; the optional
// tensors are taken by non-const lvalue reference, so callers must pass a
// named std::optional object rather than a temporary.
// NOTE(review): shape legend is presumably b = batch, s_q / s_k = sequence
// lengths, h / h_k = query / key heads, d / dv = head dims — confirm against
// the flash-attention sources.
std::vector<at::Tensor> mha_fwd(
    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                          // h_k, d) if there is page_table.
    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
    std::optional<const at::Tensor>&
        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>&
        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& q_v_,           // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor>& out_,                 // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<const at::Tensor>& cu_seqlens_q_,  // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_,  // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
    std::optional<const at::Tensor>&
        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor>&
        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int> max_seqlen_k_,
    std::optional<const at::Tensor>& page_table_,      // (b_k, max_num_pages_per_seq)
    std::optional<const at::Tensor>& kv_batch_idx_,    // b. indices to index into the KV cache
    std::optional<const at::Tensor>& leftpad_k_,       // b
    std::optional<const at::Tensor>& rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& seqlens_rotary_,  // b
    std::optional<at::Tensor>& q_descale_,             // (b, h_k), not (b, h)
    std::optional<at::Tensor>& k_descale_,             // (b, h_k)
    std::optional<at::Tensor>& v_descale_,             // (b, h_k)
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
    int num_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);