torch_bindings.cpp 30.6 KB
Newer Older
1
2
3
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
4
#include "core/registration.h"
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

#include <torch/library.h>

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

21
22
23
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

24
25
26
27
28
29
30
31
32
  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
33
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
34
      "    int tp_rank, int blocksparse_local_blocks,"
35
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
36
37
38
39
40
41
42
43
44
45
46
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
47
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
48
49
50
51
52
53
54
55
56
57
58
59
60
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention. (opt)
  ops.def(
      "paged_attention_v1_opt("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
61
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
62
63
64
65
66
67
68
69
70
71
72
73
74
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);

  // PagedAttention V2 (opt). 
  ops.def(
      "paged_attention_v2_opt("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
75
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
76
77
78
79
80
81
82
83
84
85
86
87
88
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);

  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention. (opt)
  ops.def(
      "paged_attention_v1_opt_tc("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
89
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
90
91
92
93
94
95
96
97
98
99
100
101
102
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);

  // PagedAttention V2 (opt). 
  ops.def(
      "paged_attention_v2_opt_tc("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
103
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
104
105
106
107
108
109
110
111
112
113
114
115
116
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);


  // paged_attention with atth_masks
  ops.def(
      "paged_attention_v1_with_mask("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
117
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
118
119
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
120
121
122
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
123
  ops.impl("paged_attention_v1_with_mask", torch::kCUDA, &paged_attention_v1_with_mask);
124
125
126

  // PagedAttention V2.
  ops.def(
127
      "paged_attention_v2_with_mask("
128
129
      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
130
131
132
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
133
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
134
      "    int tp_rank, int blocksparse_local_blocks,"
135
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
136
137
138
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
139
  ops.impl("paged_attention_v2_with_mask", torch::kCUDA, &paged_attention_v2_with_mask);
140

zhuwenwen's avatar
zhuwenwen committed
141
142
143
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention. (opt)
  ops.def(
144
      "paged_attention_v1_opt_with_mask("
zhuwenwen's avatar
zhuwenwen committed
145
146
147
148
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
149
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
zhuwenwen's avatar
zhuwenwen committed
150
151
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
152
153
154
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
155
  ops.impl("paged_attention_v1_opt_with_mask", torch::kCUDA, &paged_attention_v1_opt_with_mask);
zhuwenwen's avatar
zhuwenwen committed
156
157
158

  // PagedAttention V2 (opt). 
  ops.def(
159
      "paged_attention_v2_opt_with_mask("
zhuwenwen's avatar
zhuwenwen committed
160
161
162
163
164
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
165
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
zhuwenwen's avatar
zhuwenwen committed
166
167
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
168
169
170
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
171
  ops.impl("paged_attention_v2_opt_with_mask", torch::kCUDA, &paged_attention_v2_opt_with_mask);
zhuwenwen's avatar
zhuwenwen committed
172

173
  // Compute the attention between an input query and the cached
zhuwenwen's avatar
zhuwenwen committed
174
  // keys/values using PagedAttention. (opt)
175
  ops.def(
176
      "paged_attention_v1_opt_tc_with_mask("
177
178
179
180
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
181
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
zhuwenwen's avatar
zhuwenwen committed
182
      "    int tp_rank, int blocksparse_local_blocks,"
183
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
184
185
186
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
187
  ops.impl("paged_attention_v1_opt_tc_with_mask", torch::kCUDA, &paged_attention_v1_opt_tc_with_mask);
188

zhuwenwen's avatar
zhuwenwen committed
189
  // PagedAttention V2 (opt). 
190
  ops.def(
191
      "paged_attention_v2_opt_tc_with_mask("
192
193
194
195
196
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
197
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
zhuwenwen's avatar
zhuwenwen committed
198
      "    int tp_rank, int blocksparse_local_blocks,"
199
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
200
201
202
      "    int blocksparse_head_sliding_step,"
      "    Tensor? attn_masks,"
      "    int attn_masks_stride) -> ()");
203
204
  ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);

205

206
207
208
209
210
  // Activation ops
  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

211
212
213
  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

214
215
216
217
218
219
220
221
  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

zhuwenwen's avatar
zhuwenwen committed
222
223
224
225
226
227
228
229
230
231
232
233
  // Activation function used in SwiGLU. (opt)
  ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation. (opt)
  ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation. (opt)
  ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul);

234
235
236
237
  // FATReLU implementation.
  ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
  ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

238
239
240
241
242
243
244
245
  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCUDA, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

246
247
248
249
  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

250
  // prepare_inputs advance_step
251
  ops.def(
252
      "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
253
254
255
      "Tensor! input_tokens, Tensor sampled_token_ids, "
      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
      "Tensor block_tables) -> ()");
256
257
258
259
260
261
262
263
264
265
266
267
  ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);

  ops.def(
      "advance_step_flashinfer("
      "    int num_seqs, int num_queries, int block_size,"
      "    Tensor! input_tokens, Tensor sampled_token_ids,"
      "    Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
      "    Tensor block_tables, Tensor! paged_kv_indices,"
      "    Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
      "    Tensor! block_table_bounds"
      ") -> ()");
  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
268

269
270
271
  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
272
      "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
273
274
275
276
277
278
279
280
281
      "()");
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

zhuwenwen's avatar
zhuwenwen committed
282
283
284
285
286
287
288
289
290
291
292
293
  // Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
  ops.def(
      "rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);

  // In-place fused Add and RMS Normalization. (opt)
  ops.def(
      "fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);

zhuwenwen's avatar
zhuwenwen committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
  // Layernorm-quant
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
//   ops.def(
//       "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
//       "Tensor scale, float epsilon) -> "
//       "()");
//   ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
//            &rms_norm_static_fp8_quant);

  // In-place fused Add and RMS Normalization.
//   ops.def(
//       "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
//       "Tensor! residual, Tensor weight, "
//       "Tensor scale, float epsilon) -> ()");
//   ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
//            &fused_add_rms_norm_static_fp8_quant);
310

311
312
313
314
315
316
317
318
  // Fused Layernorm + Quant kernels
  ops.def(
      "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
      "Tensor weight, Tensor! scale, float epsilon, "
      "Tensor? scale_ub, Tensor!? residual) -> ()");
  ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
           &rms_norm_dynamic_per_token_quant);

319
320
321
322
323
324
325
326
  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

huangwb's avatar
huangwb committed
327
328
329
330
331
332
333
334
335
  // Rotary embedding TGI for TGI
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding_tgi(Tensor! query, Tensor! key,"
      "                 int head_size, Tensor cos_cache,"
      "                 Tensor sin_cache, bool is_neox) -> ()");
//   ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
  ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);

336
337
338
339
340
341
342
343
344
345
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  // (supports multiple loras).
  ops.def(
      "batched_rotary_embedding(Tensor positions, Tensor! query,"
      "                         Tensor! key, int head_size,"
      "                         Tensor cos_sin_cache, bool is_neox,"
      "                         int rot_dim,"
      "                         Tensor cos_sin_cache_offsets) -> ()");
  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);

zhuwenwen's avatar
zhuwenwen committed
346
347
348
349
  // trans w16
  ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
  ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);

350
351
352
  // Quantization ops
#ifndef USE_ROCM
  // Quantized GEMM for AQLM.
353
354
355
356
  ops.def(
      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
      "-> Tensor");
357
358
359
  ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

  // Decompression method for AQLM.
360
361
362
  ops.def(
      "aqlm_dequant(Tensor codes, Tensor codebooks, "
      "int[] codebook_partition_sizes) -> Tensor");
363
364
365
  ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

  // Quantized GEMM for AWQ.
366
367
  ops.def(
      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
368
      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
369
370
371
  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

  // Dequantization for AWQ.
372
373
  ops.def(
      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
374
      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
375
376
  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

377
378
379
380
381
382
383
384
385
386
387
388
389
390
  // Note about marlin kernel 'workspace' arguments:
  // Technically these should be mutable since they are modified by the kernel.
  // But since they are set back to zero once the kernel is finished we can
  // hand wave and say that they have no net effect.
  //
  // The reason to mark 'workspace' as immutable is so that they don't interfere
  // with using ScalarType arguments in the ops. If they are marked as mutable,
  // pytorch throws an assert in
  // 'torch._higher_order_ops._register_effectful_op' that prevents these
  // kernels from being torch.compile'd.
  // See the following document for more info on custom types and ops that use
  // custom types:
  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA

391
  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
392
393
  ops.def(
      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
394
395
      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
      "Tensor");
396
  // conditionally compiled so impl in source file
397
398

  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
399
400
401
  ops.def(
      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
      "Tensor b_scales, Tensor workspace, "
402
403
      "int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
404
  //  conditionally compiled so impl in source file
405

406
407
  // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  ops.def(
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
      "machete_supported_schedules("
      "   ScalarType a_type,"
      "   int b_type,"
      "   ScalarType? maybe_group_scales_type,"
      "   ScalarType? maybe_group_zeros_type,"
      "   ScalarType? maybe_channel_scales_type,"
      "   ScalarType? maybe_token_scales_type,"
      "   ScalarType? maybe_out_type"
      ") -> str[]");
  ops.def(
      "machete_mm("
      "   Tensor A,"
      "   Tensor B,"
      "   int b_type,"
      "   ScalarType? out_type,"
      "   Tensor? group_scales,"
      "   Tensor? group_zeros,"
      "   int?    group_size,"
      "   Tensor? channel_scales,"
      "   Tensor? token_scales,"
      "   str?    schedule"
      ") -> Tensor");
  ops.def(
      "machete_prepack_B("
      "   Tensor B,"
      "   ScalarType a_type,"
      "   int b_type,"
      "   ScalarType? group_scales_type"
      ") -> Tensor");
437
  // conditionally compiled so impl registration is in source file
438

439
440
441
  ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
  ops.impl("permute_cols", torch::kCUDA, &permute_cols);

442
  // gptq_marlin Optimized Quantized GEMM for GPTQ.
443
444
445
  ops.def(
      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
446
447
      "int b_q_type, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
448
      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
449
  // conditionally compiled so impl registration is in source file
450
451

  // gptq_marlin repack from GPTQ.
452
453
454
  ops.def(
      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
455
  // conditionally compiled so impl registrations are in source file
456

457
  // awq_marlin repack from AWQ.
458
459
460
  ops.def(
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits) -> Tensor");
461
  // conditionally compiled so impl registrations are in source file
462
#endif
463

464
  // Dequantization for GGML.
465
  ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor");
466
467
468
  ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

  // mmvq kernel for GGML.
469
  ops.def(
470
      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
471
      "-> Tensor");
472
473
474
  ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

  // mmq kernel for GGML.
475
476
  ops.def(
      "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
477
478
  ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

479
#ifndef USE_ROCM
480
  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
481
482
  ops.def(
      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
483
484
      "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor");
485
  // conditionally compiled so impl registration is in source file
486

487
  // marlin_qqq_gemm for QQQ.
488
489
490
  ops.def(
      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
491
492
      "Tensor! workspace, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor");
493
  // conditionally compiled so impl registration is in source file
494

495
  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
496
  // quantization, as well as bias
497
  ops.def(
498
499
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
500
      "                  Tensor b_scales, Tensor? bias) -> ()");
501
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
502

503
504
505
506
507
508
509
510
511
  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor azp_adj,"
      "                  Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

512
513
  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability
514
515
516
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

517
518
519
520
521
522
523
  // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
  ops.def(
      "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
      "bool");
  ops.impl("cutlass_scaled_mm_supports_block_fp8",
           &cutlass_scaled_mm_supports_fp8);

524
525
526
527
528
529
530
  // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
  // given capability
  ops.def(
      "cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
  ops.impl("cutlass_sparse_scaled_mm_supported",
           &cutlass_sparse_scaled_mm_supported);

531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
  // CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
      "                         Tensor bt_nzs,"
      "                         Tensor bt_meta, Tensor a_scales,"
      "                         Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

  // CUTLASS sparse matrix compressor
  ops.def(
      "cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
      "                              Tensor a) -> bool");
  ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);

546
547
548
549
  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"
      "Tensor! A, Tensor! B, Tensor! C,"
550
      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
551
      "bool delta_softplus,"
552
553
554
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
555
556
      "Tensor! ssm_states,"
      "int pad_slot_id) -> ()");
557
558
559
560
561
562
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  ops.def(
      "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
563
      "Tensor? bias_,"
564
      "bool silu_activation,"
565
      "Tensor? cache_seqlens_,"
566
567
      "Tensor? conv_state_indices,"
      "int pad_slot_id) -> ()");
568
569
570
571
572
  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

  ops.def(
      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
      "Tensor? bias_,"
573
574
575
576
      "Tensor!? conv_states,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
577
578
      "bool silu_activation,"
      "int pad_slot_id) -> ()");
579
  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
580
581
582
#endif

  // Quantized GEMM for GPTQ.
583
584
  // Note: even though the C++ inferred schema is correct for this op, it seems
  // to prevent the meta function registry.
zhuwenwen's avatar
zhuwenwen committed
585
586
587
588
//   ops.def(
//       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
//       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
//       "-> Tensor");
589
//   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
590
591

  // Post processing for GPTQ.
592
593
//   ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
//   ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
594
595

  // Compute FP8 quantized tensor for given scaling factor.
zhuwenwen's avatar
zhuwenwen committed
596
//   ops.def(
zhuwenwen's avatar
zhuwenwen committed
597
598
//       "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//       "()");
zhuwenwen's avatar
zhuwenwen committed
599
//   ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
600

zhuwenwen's avatar
zhuwenwen committed
601
//   // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
zhuwenwen's avatar
zhuwenwen committed
602
//   ops.def(
zhuwenwen's avatar
zhuwenwen committed
603
604
//       "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//       "-> "
zhuwenwen's avatar
zhuwenwen committed
605
606
//       "()");
//   ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
607

zhuwenwen's avatar
zhuwenwen committed
608
//   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
609
//   ops.def(
zhuwenwen's avatar
zhuwenwen committed
610
//       "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
zhuwenwen's avatar
zhuwenwen committed
611
//       "Tensor! scale, Tensor? scale_ub) -> "
612
613
614
//       "()");
//   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
//            &dynamic_per_token_scaled_fp8_quant);
615

616
617
  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
618
      "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
619
      "Tensor? azp) -> ()");
620
621
622
623
  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
624
      "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
625
      "Tensor!? azp) -> ()");
626
627
628
629
630
631
632
633
634
635
636
637
638
  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
           &dynamic_scaled_int8_quant);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy the cache blocks from src to dst.
  cache_ops.def(
639
640
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
641
642
  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

643
644
645
646
  cache_ops.def(
      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);

647
648
649
650
651
652
  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
653
      "                  Tensor k_scale, Tensor v_scale) -> ()");
654
655
656
657
658
659
660
661
  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      "                        Tensor! key_cache,"
      "                        Tensor! value_cache,"
      "                        Tensor slot_mapping,"
662
      "                        str kv_cache_dtype,"
663
      "                        Tensor k_scale, Tensor v_scale) -> ()");
664
665
666
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);

667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
  // read key and value form kv cache
  cache_ops.def(
      "read_cache(Tensor keys, Tensor values,"
      "                  Tensor[]! key_caches, Tensor[]! value_caches,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype) -> ()");
  cache_ops.impl("read_cache", torch::kCUDA, &read_cache);

  // write multi-layers key and value to kv cache
  cache_ops.def(
      "write_cache_multi_layers(Tensor keys, Tensor values,"
      "                  Tensor[]! key_caches, Tensor[]! value_caches,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype) -> ()");
  cache_ops.impl("write_cache_multi_layers", torch::kCUDA, &write_cache_multi_layers);

683
684
685
686
687
688
689
690
691
  // Concat kv_c and k_pe and cache them.
  cache_ops.def(
      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      "                     Tensor! kv_cache,"
      "                     Tensor slot_mapping,"
      "                     str kv_cache_dtype,"
      "                     Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

692
693
  // Convert the key and value cache to fp8 data type.
  cache_ops.def(
694
695
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
      "str kv_cache_dtype) -> ()");
696
697
698
699
700
701
702
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  // Cuda utils

  // Gets the specified device attribute.
703
704
  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
  cuda_utils.impl("get_device_attribute", &get_device_attribute);
705
706

  // Gets the maximum shared memory per block device attribute.
707
708
  cuda_utils.def(
      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
709
710
711
712
713
714
715
  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
                  &get_max_shared_memory_per_block_device_attribute);
}

#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  // Custom all-reduce kernels
716
  custom_ar.def(
717
718
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool full_nvlink) -> int");
719
720
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  custom_ar.def(
721
722
723
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
724
725
726
727

  custom_ar.def("dispose", &dispose);
  custom_ar.def("meta_size", &meta_size);

728
  custom_ar.def("register_buffer", &register_buffer);
729
730
731
732
733
734
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.def("register_graph_buffers", &register_graph_buffers);
}
#endif

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)