// torch_bindings.cpp
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"

#include <torch/library.h>

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

// Registers the main custom-op library: attention, activation, layernorm,
// rotary-embedding, quantization and MoE ops. Each op is declared with a
// schema (ops.def) and bound to its CUDA-dispatch implementation (ops.impl).
// Ops defined from a bare function pointer (e.g. ops.def("x", &x)) have their
// schema inferred from the C++ signature.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2. Takes extra exp_sums / max_logits / tmp_out buffers
  // in addition to the V1 arguments.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention. (opt; same schema as
  // paged_attention_v1, bound to a separate implementation)
  ops.def(
      "paged_attention_v1_opt("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);

  // PagedAttention V2 (opt; same schema as paged_attention_v2, bound to a
  // separate implementation).
  ops.def(
      "paged_attention_v2_opt("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);

  // Activation ops
  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

  // NOTE(review): the three *_opt activation ops below are bound to the SAME
  // kernels as their non-opt counterparts (&silu_and_mul, &gelu_and_mul,
  // &gelu_tanh_and_mul). The paged_attention / rms_norm *_opt ops instead
  // bind dedicated *_opt functions. Confirm this aliasing is intended; if
  // dedicated kernels exist in ops.h, bind them here.

  // Activation function used in SwiGLU. (opt)
  ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation. (opt)
  ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation. (opt)
  ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCUDA, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

  // prepare_inputs advance_step (schema inferred from the C++ signature).
  ops.def("advance_step", &advance_step);
  ops.impl("advance_step", torch::kCUDA, &advance_step);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

  // Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
  ops.def(
      "rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);

  // In-place fused Add and RMS Normalization. (opt)
  ops.def(
      "fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

  // Rotary embedding for TGI.
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. Unlike
  // rotary_embedding, this variant takes no positions tensor and receives
  // separate cos_cache / sin_cache tensors.
  ops.def(
      "rotary_embedding_tgi(Tensor! query, Tensor! key,"
      "                 int head_size, Tensor cos_cache,"
      "                 Tensor sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);

  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  // (supports multiple loras).
  ops.def(
      "batched_rotary_embedding(Tensor positions, Tensor! query,"
      "                         Tensor! key, int head_size,"
      "                         Tensor cos_sin_cache, bool is_neox,"
      "                         int rot_dim,"
      "                         Tensor cos_sin_cache_offsets) -> ()");
  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);

  // trans w16: writes into `dst` from `src` with the given row/col extents.
  // NOTE(review): presumably a 16-bit weight transpose/layout conversion for
  // GEMM — confirm against the kernel implementation.
  ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
  ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);

  // Quantization ops
#ifndef USE_ROCM
  // Quantized GEMM for AQLM.
  ops.def("aqlm_gemm", &aqlm_gemm);
  ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

  // Decompression method for AQLM.
  ops.def("aqlm_dequant", &aqlm_dequant);
  ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

  // Quantized GEMM for AWQ.
  ops.def("awq_gemm", &awq_gemm);
  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

  // Dequantization for AWQ.
  ops.def("awq_dequantize", &awq_dequantize);
  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  ops.def("marlin_gemm", &marlin_gemm);
  ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);

  // gptq_marlin Optimized Quantized GEMM for GPTQ.
  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
  ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

  // gptq_marlin repack from GPTQ.
  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
  ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

  // awq_marlin repack from AWQ.
  ops.def("awq_marlin_repack", &awq_marlin_repack);
  ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);

  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
  ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);

  // marlin_qqq_gemm for QQQ.
  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
  ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability.
  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
           &cutlass_scaled_mm_supports_fp8);
#endif

  // Quantized GEMM for GPTQ.
  ops.def("gptq_gemm", &gptq_gemm);
  ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

  // Post processing for GPTQ.
  ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

  // Quantized GEMM for SqueezeLLM.
  ops.def(
      "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
      "lookup_table) -> ()");
  ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);

  // NOTE(review): the three FP8 quantization ops below are disabled
  // (commented out) in this build; the reason is not visible in this file.

  // Compute FP8 quantized tensor for given scaling factor.
  //   ops.def(
  //       "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  //   ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

  // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  //   ops.def(
  //       "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  //       "()");
  //   ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

  // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  //   ops.def(
  //       "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
  //       "scale, Tensor? scale_ub) -> "
  //       "()");
  //   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
  //            &dynamic_per_token_scaled_fp8_quant);

  // Aligning the number of tokens to be processed by each expert such
  // that it is divisible by the block size.
  ops.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts,"
      "                     int block_size, Tensor! sorted_token_ids,"
      "                     Tensor! experts_ids,"
      "                     Tensor! num_tokens_post_pad) -> ()");
  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
      "()");
  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor.
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
      "()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
           &dynamic_scaled_int8_quant);
}

// Registers the KV-cache management ops under the "<ext>_cache_ops" library.
// Every op is defined with an explicit schema and bound to a CUDA-dispatch
// implementation.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst according to
  // block_mapping; only dst is mutated.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy cache blocks according to block_mapping. Both the key_caches and
  // value_caches lists are mutated in place; presumably each mapping entry
  // is a (src block, dst block) pair — confirm against the kernel.
  cache_ops.def(
      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
      "block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  // Reshape the key and value tensors and write them into key_cache /
  // value_cache at the slots given by slot_mapping.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  float k_scale, float v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Same schema as reshape_and_cache, bound to a separate implementation.
  // NOTE(review): presumably targets the flash-attention cache layout —
  // confirm against reshape_and_cache_flash.
  cache_ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      "                        Tensor! key_cache,"
      "                        Tensor! value_cache,"
      "                        Tensor slot_mapping,"
      "                        str kv_cache_dtype,"
      "                        float k_scale, float v_scale) -> ()");
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);

  // Convert the key and value cache to fp8 data type: writes src_cache into
  // dst_cache using the given scale and kv_cache_dtype.
  cache_ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
      "kv_cache_dtype) -> ()");
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

// Device-query helpers exposed as the "<ext>_cuda_utils" library. Schemas are
// inferred from each function's C++ signature; both ops are bound to the
// CUDA dispatch key.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), lib) {
  // get_device_attribute: fetch a single device attribute.
  lib.def("get_device_attribute", &get_device_attribute);
  lib.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);

  // Convenience query for the max-shared-memory-per-block attribute.
  lib.def("get_max_shared_memory_per_block_device_attribute",
          &get_max_shared_memory_per_block_device_attribute);
  lib.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA,
           &get_max_shared_memory_per_block_device_attribute);
}

#ifndef USE_ROCM
// Custom all-reduce ops, registered as the "<ext>_custom_ar" library.
// Compiled out entirely when building for ROCm.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), m) {
  // --- CUDA dispatch key; schema inferred from the C++ signature ---
  m.def("init_custom_ar", &init_custom_ar);
  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  m.def("should_custom_ar", &should_custom_ar);
  m.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

  // --- CUDA dispatch key; explicit schemas (`out` is written in place) ---
  m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

  m.def(
      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
      "()");
  m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);

  // --- CPU dispatch key ---
  m.def("dispose", &dispose);
  m.impl("dispose", torch::kCPU, &dispose);

  m.def("meta_size", &meta_size);
  m.impl("meta_size", torch::kCPU, &meta_size);

  // --- CUDA dispatch key ---
  m.def("register_buffer", &register_buffer);
  m.impl("register_buffer", torch::kCUDA, &register_buffer);

  // --- CPU dispatch key (graph-buffer IPC helpers) ---
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  m.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta);

  m.def("register_graph_buffers", &register_graph_buffers);
  m.impl("register_graph_buffers", torch::kCPU, &register_graph_buffers);
}
#endif

// Emits the extension's module entry point for TORCH_EXTENSION_NAME.
// NOTE(review): macro presumably defined in core/registration.h (included
// above) — confirm what it generates there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)