// torch_bindings.cpp: registers vLLM's custom ops with the PyTorch dispatcher.
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations

// Registers the main vLLM custom ops in the TORCH_EXTENSION_NAME library.
// Ops are declared either with an explicit schema string (arguments marked
// `Tensor!` are written in place by the kernel) or from a C++ function pointer,
// in which case torch::Library infers the schema from the function signature.
// Every op gets a CUDA implementation via ops.impl(..., torch::kCUDA, ...).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention. (opt)
  // Same schema as paged_attention_v1, bound to paged_attention_v1_opt.
  ops.def(
      "paged_attention_v1_opt("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);

  // PagedAttention V2 (opt).
  // Same schema as paged_attention_v2, bound to paged_attention_v2_opt.
  ops.def(
      "paged_attention_v2_opt("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float k_scale, float v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);

  // Activation ops
  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

  // NOTE(review): the three `_opt` activation ops below are bound to the
  // non-opt kernels (silu_and_mul, gelu_and_mul, gelu_tanh_and_mul). This is
  // unlike the paged-attention and rms_norm `_opt` ops, which bind dedicated
  // `*_opt` functions. Confirm against ops.h whether `_opt` activation kernels
  // exist and should be bound here instead.

  // Activation function used in SwiGLU. (opt)
  ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation. (opt)
  ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation. (opt)
  ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCUDA, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

  // prepare_inputs advance_step (schema inferred from the C++ signature).
  ops.def("advance_step", &advance_step);
  ops.impl("advance_step", torch::kCUDA, &advance_step);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

  // Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
  ops.def(
      "rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);

  // In-place fused Add and RMS Normalization. (opt)
  ops.def(
      "fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

  // Rotary embedding variant for TGI: takes separate cos/sin caches and no
  // positions tensor. Applies GPT-NeoX or GPT-J style rotary embedding to
  // query and key in place.
  ops.def(
      "rotary_embedding_tgi(Tensor! query, Tensor! key,"
      "                 int head_size, Tensor cos_cache,"
      "                 Tensor sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);

  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  // (supports multiple loras).
  ops.def(
      "batched_rotary_embedding(Tensor positions, Tensor! query,"
      "                         Tensor! key, int head_size,"
      "                         Tensor cos_sin_cache, bool is_neox,"
      "                         int rot_dim,"
      "                         Tensor cos_sin_cache_offsets) -> ()");
  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);

  // Weight transform for 16-bit GEMM: writes `src` (row x col) into `dst`.
  ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
  ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);

  // Quantization ops
  // The ops in this section are not registered when USE_ROCM is defined.
#ifndef USE_ROCM
  // Quantized GEMM for AQLM.
  ops.def("aqlm_gemm", &aqlm_gemm);
  ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

  // Decompression method for AQLM.
  ops.def("aqlm_dequant", &aqlm_dequant);
  ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

  // Quantized GEMM for AWQ.
  ops.def("awq_gemm", &awq_gemm);
  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

  // Dequantization for AWQ.
  ops.def("awq_dequantize", &awq_dequantize);
  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  ops.def("marlin_gemm", &marlin_gemm);
  ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);

  // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  ops.def("machete_supported_schedules", &machete::supported_schedules);
  ops.def(
      "machete_gemm(Tensor A, Tensor B,"
      "             __torch__.torch.classes._core_C.ScalarType btype,"
      "             Tensor? scales, Tensor? zeros, int? group_size,"
      "             Tensor? C, float? alpha, float? beta, str? schedule)"
      "-> Tensor");
  ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
  ops.def(
      "machete_prepack_B(Tensor B,"
      "                  __torch__.torch.classes._core_C.ScalarType btype)"
      "-> Tensor");
  ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);

  // gptq_marlin Optimized Quantized GEMM for GPTQ.
  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
  ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

  // gptq_marlin repack from GPTQ.
  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
  ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

  // awq_marlin repack from AWQ.
  ops.def("awq_marlin_repack", &awq_marlin_repack);
  ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);

  // Dequantization for GGML.
  ops.def("ggml_dequantize", &ggml_dequantize);
  ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

  // mmvq kernel for GGML.
  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
  ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

  // mmq kernel for GGML.
  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
  ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
  ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);

  // marlin_qqq_gemm for QQQ.
  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
  ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor azp_adj,"
      "                  Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
           &cutlass_scaled_mm_supports_fp8);

  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"
      "Tensor! A, Tensor! B, Tensor! C,"
      "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
      "bool delta_softplus,"
      "Tensor? index_, Tensor? x) -> Tensor[]");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  // Mamba causal conv1d: single-step update of conv_state.
  ops.def(
      "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
      "Tensor? bias_,"
      "bool silu_activation) -> Tensor");
  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

  // Mamba causal conv1d: forward pass over a full sequence.
  ops.def(
      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
      "Tensor? bias_,"
      "Tensor? seq_idx_,"
      "Tensor? initial_states_,"
      "Tensor? final_states_out_,"
      "bool silu_activation) -> Tensor");
  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif

  // Quantized GEMM for GPTQ.
  // Disabled in this build: the registration is kept commented out.
//   ops.def("gptq_gemm", &gptq_gemm);
//   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

  // Post processing for GPTQ.
  // Disabled in this build: the registration is kept commented out.
//   ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
//   ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

  // Quantized GEMM for SqueezeLLM.
  ops.def(
      "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
      "lookup_table) -> ()");
  ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);

  // Compute FP8 quantized tensor for given scaling factor.
  // Disabled in this build: the registration is kept commented out.
//   ops.def(
//       "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
//   ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

  // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  // Disabled in this build: the registration is kept commented out.
//   ops.def(
//       "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
//       "()");
//   ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

  // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  // Disabled in this build: the registration is kept commented out.
//   ops.def(
//       "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
//       "scale, Tensor? scale_ub) -> "
//       "()");
//   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
//            &dynamic_per_token_scaled_fp8_quant);

  // Aligning the number of tokens to be processed by each expert such
  // that it is divisible by the block size.
  ops.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts,"
      "                     int block_size, Tensor! sorted_token_ids,"
      "                     Tensor! experts_ids,"
      "                     Tensor! num_tokens_post_pad) -> ()");
  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
      "()");
  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
      "()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
           &dynamic_scaled_int8_quant);
}

// Registers the KV-cache management ops in the
// TORCH_EXTENSION_NAME##_cache_ops library. Each op has an explicit schema
// (`Tensor!` / `Tensor[]!` mark arguments written in place) and a CUDA
// implementation.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst according to block_mapping.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // Copy cache blocks within each key/value cache according to block_mapping.
  cache_ops.def(
      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
      "block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

  // Reshape the key and value tensors and write them into key_cache /
  // value_cache at the slots given by slot_mapping.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  float k_scale, float v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

  // Same signature as reshape_and_cache; presumably targets the cache layout
  // used by the flash-attention backend (TODO confirm against cache.h).
  cache_ops.def(
      "reshape_and_cache_flash(Tensor key, Tensor value,"
      "                        Tensor! key_cache,"
      "                        Tensor! value_cache,"
      "                        Tensor slot_mapping,"
      "                        str kv_cache_dtype,"
      "                        float k_scale, float v_scale) -> ()");
  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                 &reshape_and_cache_flash);

  // Convert the key and value cache to fp8 data type.
  cache_ops.def(
      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
      "kv_cache_dtype) -> ()");
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

// Device-query helpers, exposed as the TORCH_EXTENSION_NAME##_cuda_utils
// library. Both ops take their schema from the C++ function signature and are
// dispatched on the CUDA key.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), lib) {
  // Query an arbitrary device attribute.
  lib.def("get_device_attribute", &get_device_attribute);
  lib.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);

  // Query the "max shared memory per block" device attribute.
  lib.def("get_max_shared_memory_per_block_device_attribute",
          &get_max_shared_memory_per_block_device_attribute);
  lib.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA,
           &get_max_shared_memory_per_block_device_attribute);
}

// Custom all-reduce ops, exposed as the TORCH_EXTENSION_NAME##_custom_ar
// library. Not registered when USE_ROCM is defined. All schemas except the two
// all_reduce_* ops are inferred from the C++ signatures. Registration order
// below matches the order the ops are declared in.
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), ar) {
  // Set up the custom all-reduce state (CUDA dispatch).
  ar.def("init_custom_ar", &init_custom_ar);
  ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  // Decide whether the custom all-reduce path should be used (CUDA dispatch).
  ar.def("should_custom_ar", &should_custom_ar);
  ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

  // All-reduce with a registered input; result written into `out`.
  ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

  // All-reduce of an unregistered input, staged through `reg_buffer`.
  ar.def(
      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
      "()");
  ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);

  // Tear down (CPU dispatch).
  ar.def("dispose", &dispose);
  ar.impl("dispose", torch::kCPU, &dispose);

  // Size of the metadata region (CPU dispatch).
  ar.def("meta_size", &meta_size);
  ar.impl("meta_size", torch::kCPU, &meta_size);

  // Register an input buffer (CUDA dispatch).
  ar.def("register_buffer", &register_buffer);
  ar.impl("register_buffer", torch::kCUDA, &register_buffer);

  // Graph-buffer IPC handling (both on CPU dispatch).
  ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta);

  ar.def("register_graph_buffers", &register_graph_buffers);
  ar.impl("register_graph_buffers", torch::kCPU, &register_graph_buffers);
}
#endif

// Extension-module registration macro from core/registration.h.
// NOTE(review): presumably emits the Python module init entry point for
// TORCH_EXTENSION_NAME so the libraries above load on import; confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)