torch_bindings.cpp 21.4 KB
Newer Older
1
2
#include "cache.h"
#include "ops.h"
3
#include "core/registration.h"
4
5
6

#include <torch/library.h>

Jiayi Yan's avatar
Jiayi Yan committed
7
// Note: override any externally supplied definition of TORCH_EXTENSION_NAME
// (e.g. one passed as a -D compiler flag) so that the libraries built for
// different ISAs all share the same extension name. The #undef is needed
// because redefining a macro to a different value without it triggers a
// macro-redefinition diagnostic.
#undef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME _C

11
12
13
14
15
16
17
18
19
20
21
22
23
// oneDNN GEMM entry points. They are registered with the dispatcher only on
// AVX-512, non-Apple AArch64 and PowerPC64 builds.

// Releases a oneDNN matmul handler. `handler` is presumably the int64 value
// returned by one of the create_onednn_*_handler functions below.
void release_dnnl_matmul_handler(int64_t handler);

// Creates a oneDNN W8A8 scaled-matmul handler for weight `b` / `b_scales` and
// returns it as an int64 handle.
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
                                        const torch::Tensor& b_scales,
                                        at::ScalarType output_type,
                                        bool dynamic_act_quant, bool use_azp,
                                        int64_t primitive_cache_size);

// W8A8 oneDNN matmul. The result is written into `c`, which the op schema
// marks as mutated (`Tensor! c`).
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const torch::Tensor& a_scales,
                      const std::optional<torch::Tensor>& azp,
                      const std::optional<torch::Tensor>& azp_adj,
                      const std::optional<torch::Tensor>& bias,
                      const torch::Tensor& handler_tensor);

// Creates a plain oneDNN GEMM handler for weight `b` and returns it as an
// int64 handle.
int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                 int64_t primitive_cache_size);

// oneDNN GEMM. The result is written into `c` (schema: `Tensor! c`).
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
               const std::optional<torch::Tensor>& bias,
               const torch::Tensor& handler_tensor);

// Reports whether oneDNN was built with the ACL backend.
bool is_onednn_acl_supported();

Thien Tran's avatar
Thien Tran committed
35
36
37
38
// MLA decode kernel. The result is written into `out`, which the op schema
// marks as mutated (`Tensor! out`). Registered for every CPU build.
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens);

// Shared-memory (SHM) collective communication. These ops are registered only
// on AVX-512 and non-Apple AArch64 builds.

// Creates a shared-memory manager and returns an int64 handle. The other
// shm_* functions below take that handle as their first argument.
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
                         const int64_t rank, const int64_t thread_num);

std::string join_shm_manager(int64_t handle, const std::string& name);

// In-place all-reduce on `data` (schema: `Tensor! data`).
void shm_allreduce(int64_t handle, torch::Tensor& data);

void shm_gather(int64_t handle, torch::Tensor& data,
                const std::optional<std::vector<torch::Tensor>>& outputs,
                int64_t dst);

// All-gather of `data` into `output` (schema: `Tensor! output`).
void shm_all_gather(int64_t handle, const torch::Tensor& data,
                    torch::Tensor& output);

void shm_send_tensor_list(int64_t handle,
                          const std::vector<torch::Tensor>& tensor_list,
                          int64_t dst);

std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Kernels adapted from sgl-kernel. They are registered only when AVX512BF16,
// AVX512F and AVX512VNNI are all available at compile time.

// Linear layer over a packed weight `mat2`; `is_vnni` flags the weight layout.
at::Tensor weight_packed_linear(
    at::Tensor& mat1,
    at::Tensor& mat2,
    const std::optional<at::Tensor>& bias,
    bool is_vnni);

// Converts `weight` into the packed layout consumed by the kernels above.
at::Tensor convert_weight_packed(at::Tensor& weight);

// Fused MoE experts. Quantization mode is selected by the use_int8_w8a8 /
// use_fp8_w8a16 flags; scales are optional.
at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& topk_weights,
    at::Tensor& topk_ids,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// int8 scaled matmul that also quantizes the activation `mat1`.
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

80
81
82
83
84
85
86
87
// INT4 W4A8 kernels, adapted from sglang. They share the sgl-kernel ISA guard
// (AVX512BF16 && AVX512F && AVX512VNNI) at registration time.

// Repacks `qweight`, `qzeros` and `scales`; returns the three repacked tensors.
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
    at::Tensor qweight,
    at::Tensor qzeros,
    at::Tensor scales);

// Note: `bias` is taken by value here, unlike the reference-taking optionals
// elsewhere; this must match the definition's signature.
at::Tensor int4_scaled_mm_cpu(
    at::Tensor& x,
    at::Tensor& w,
    at::Tensor& w_zeros,
    at::Tensor& w_scales,
    std::optional<at::Tensor> bias);

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// CPU attention kernels (registered for every CPU build).

// Builds the scheduler-metadata tensor consumed by
// cpu_attention_with_kv_cache.
// NOTE(review): `casual` looks like a typo for `causal`; it is spelled the
// same way in the registered schema, so it is left unchanged here.
torch::Tensor get_scheduler_metadata(
    const int64_t num_req,
    const int64_t num_heads_q,
    const int64_t num_heads_kv,
    const int64_t head_dim,
    const torch::Tensor& seq_lens,
    at::ScalarType dtype,
    const torch::Tensor& query_start_loc,
    const bool casual,
    const int64_t window_size,
    const std::string& isa_hint,
    const bool enable_kv_split);

// Stores `key`/`value` into `key_cache`/`value_cache`, both of which the
// schema marks as mutated.
void cpu_attn_reshape_and_cache(
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    const torch::Tensor& slot_mapping,
    const std::string& isa);

// Attention over the cached K/V. The result is written into `output`
// (schema: `Tensor(a3!) output`).
void cpu_attention_with_kv_cache(
    const torch::Tensor& query,
    const torch::Tensor& key_cache,
    const torch::Tensor& value_cache,
    torch::Tensor& output,
    const torch::Tensor& query_start_loc,
    const torch::Tensor& seq_lens,
    const double scale,
    const bool causal,
    const std::optional<torch::Tensor>& alibi_slopes,
    const int64_t sliding_window_left,
    const int64_t sliding_window_right,
    const torch::Tensor& block_table,
    const double softcap,
    const torch::Tensor& scheduler_metadata,
    const std::optional<torch::Tensor>& s_aux);

114
115
116
// Stand-in kernel for ops that have no CPU implementation.
// It is registered under static_scaled_fp8_quant, dynamic_scaled_fp8_quant
// and dynamic_per_token_scaled_fp8_quant, so code that looks those ops up at
// import time does not fail. Actually calling it always fails the
// TORCH_CHECK with "Unimplemented".
void placeholder_op() {
  TORCH_CHECK(false, "Unimplemented");
}

Li, Jiang's avatar
Li, Jiang committed
117
118
119
120
121
122
123
// WNA16 GEMM (presumably N-bit quantized weights with 16-bit activations).
// Registered only on AVX-512 builds. The result is written into `output`
// (schema: `Tensor(a2!) output`).
void cpu_gemm_wna16(
    const torch::Tensor& input,
    const torch::Tensor& q_weight,
    torch::Tensor& output,
    const torch::Tensor& scales,
    const std::optional<torch::Tensor>& zeros,
    const std::optional<torch::Tensor>& g_idx,
    const std::optional<torch::Tensor>& bias,
    const int64_t pack_factor,
    const std::string& isa_hint);

124
125
126
127
128
129
130
131
// Fused-MoE kernels, registered only on AVX-512 builds.

// Writes the repacked form of `weight` into `packed_weight` (schema:
// `Tensor(a1!) packed_weight`).
void prepack_moe_weight(const torch::Tensor& weight,
                        torch::Tensor& packed_weight, const std::string& isa);

// Fused MoE forward. The result is written into `output` (schema:
// `Tensor(a0!) output`); `act` names the activation function.
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
                   const torch::Tensor& w13, const torch::Tensor& w2,
                   const std::optional<torch::Tensor>& w13_bias,
                   const std::optional<torch::Tensor>& w2_bias,
                   const torch::Tensor& topk_weights,
                   const torch::Tensor& topk_id, const bool skip_weighted,
                   const std::string& act, const std::string& isa);
134

135
136
137
138
139
140
// Fills `slot_mapping`, which the op schema marks as mutated
// (`Tensor(a3!) slot_mapping`).
// Note: unlike most declarations in this file, the tensors are taken by
// (const) value. This must stay in sync with the out-of-file definition.
void compute_slot_mapping_kernel_impl(
    const torch::Tensor query_start_loc,
    const torch::Tensor positions,
    const torch::Tensor block_table,
    torch::Tensor slot_mapping,
    const int64_t block_size);

141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
namespace cpu_utils {
void eagle_prepare_inputs_padded_kernel_impl(
    const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& valid_sampled_tokens_count,
    const torch::Tensor& query_start_loc_gpu,
    torch::Tensor& token_indices_to_sample,
    torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs);
void eagle_prepare_next_token_padded_kernel_impl(
    const torch::Tensor& sampled_token_ids,
    const torch::Tensor& discard_request_mask,
    const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
    torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
    const int64_t num_sampled_tokens_per_req, const int64_t num_reqs);
void eagle_step_slot_mapping_metadata_kernel_impl(
    const torch::Tensor& positions, const torch::Tensor& block_table,
    torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
    torch::Tensor& out_slot_mapping, const int64_t block_size,
    const int64_t max_model_len, const int64_t PAD_ID);
void copy_and_expand_eagle_inputs_kernel_impl(
    const torch::Tensor& target_token_ids,
    const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
    torch::Tensor& out_input_ids, torch::Tensor& out_positions,
    torch::Tensor& out_is_rejected_token_mask,
    torch::Tensor& out_is_masked_token_mask,
    torch::Tensor& out_new_token_indices,
    torch::Tensor& out_hidden_state_mapping,
    const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
    const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
    const int64_t total_input_tokens,
    const int64_t num_padding_slots_per_request, const bool shift_input_ids);
void rejection_greedy_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
    const torch::Tensor& bonus_token_ids,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len);
void rejection_random_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
    const torch::Tensor& recovered_token_ids,
    const torch::Tensor& uniform_probs,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
    const int64_t vocab_size, const bool no_draft_probs);
void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
                        const torch::Tensor& cu_num_tokens,
                        const int64_t replace_from, const int64_t replace_to);
void sample_recovered_tokens_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
    const int64_t vocab_size, const bool no_draft_probs);
}  // namespace cpu_utils

196
197
198
// Registers this extension's operators with the PyTorch dispatcher.
// - Ops declared with `ops.def(schema)` + `ops.impl(name, torch::kCPU, fn)`
//   get a CPU-specific kernel.
// - Ops declared with `ops.def(schema, fn)` register `fn` directly, without
//   a CPU dispatch key.
// - Several groups are only registered when the matching ISA macros are
//   defined at compile time (see the #if guards below).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  ops.def(
      "dynamic_4bit_int_moe("
      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
      "int group_size, bool apply_router_weight_on_input, int activation_kind"
      ") -> Tensor");

  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);

  // Activation ops

  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor!? key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
    defined(__powerpc64__)
  // Helper function to release oneDNN handlers
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN GEMM handler
  ops.def(
      "create_onednn_mm_handler(Tensor b, int "
      "primitive_cache_size) -> int",
      &create_onednn_mm_handler);

  // oneDNN GEMM
  ops.def(
      "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
      "Tensor handler_tensor) -> ()");
  ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

  // Check if oneDNN was built with ACL backend
  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
      "output_type, bool dynamic_act_quant, bool use_azp, int "
      "primitive_cache_size) -> int",
      &create_onednn_scaled_mm_handler);

  // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
  ops.def(
      "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
      "Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()");
  ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
#endif  // AVX512F || (aarch64 && !APPLE) || powerpc64

  // SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
  ops.def(
      "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
      "int",
      &init_shm_manager);
  ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
  ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
  ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  ops.def(
      "shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
      "()");
  ops.impl("shm_gather", torch::kCPU, &shm_gather);
  ops.def(
      "shm_all_gather(int handle, Tensor data, Tensor! output) -> "
      "()");
  ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
  ops.def(
      "shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
      "()");
  ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
  ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
          &shm_recv_tensor_list);
#endif  // AVX512F || (aarch64 && !APPLE)

  // sgl-kernels
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
  ops.def(
      "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
      "bias, bool is_vnni) -> Tensor");
  ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
  ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
  ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
  ops.def(
      "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
      "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
      "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
      "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
      "Tensor");
  ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
  ops.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
      "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
  ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
           &int8_scaled_mm_with_quant);

  // Adapted from sglang: INT4 W4A8 kernels
  ops.def(
      "convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
      "Tensor scales) -> (Tensor, Tensor, Tensor)");
  ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
           &convert_weight_packed_scale_zp);

  ops.def(
      "int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
      "Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
  ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif  // AVX512BF16 && AVX512F && AVX512VNNI

  // CPU attention kernels
  // NOTE(review): `casual` below looks like a typo for `causal`. It is kept
  // because renaming a schema argument breaks callers passing it by keyword.
  ops.def(
      "get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
      "int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
      "query_start_loc, bool casual, int window_size, str isa_hint, bool "
      "enable_kv_split) -> Tensor",
      &get_scheduler_metadata);
  ops.def(
      "cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
      "key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
      "isa) -> ()",
      &cpu_attn_reshape_and_cache);
  ops.def(
      "cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
      "value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
      "seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
      "sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
      "float softcap, Tensor scheduler_metadata, Tensor? s_aux) -> ()",
      &cpu_attention_with_kv_cache);

  // placeholders
  ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);

  // WNA16
#if defined(__AVX512F__)
  ops.def(
      "cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
      "Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
      "pack_factor, str isa_hint) -> ()");
  ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif  // AVX512F

  // fused moe
#if defined(__AVX512F__)
  ops.def(
      "prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
      "-> ()");
  ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
  ops.def(
      "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
      "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
      "bool skip_weighted, "
      "str act, str isa) -> ()");
  ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif  // AVX512F

  // MLA decode (registered for every CPU build, no ISA guard).
  ops.def(
      "mla_decode_kvcache("
      "   Tensor! out, Tensor query, Tensor kv_cache,"
      "   float scale, Tensor block_tables, Tensor seq_lens) -> ()");
  ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);

  // Slot-mapping computation (no ISA guard).
  ops.def(
      "compute_slot_mapping_kernel_impl(Tensor query_start_loc, Tensor "
      "positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
      "block_size) -> ()",
      &compute_slot_mapping_kernel_impl);

  // Speculative decoding kernels
  ops.def(
      "eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
      "Tensor valid_sampled_tokens_count, Tensor query_start_loc_gpu, "
      "Tensor(a3!) token_indices_to_sample, "
      "Tensor(a4!) num_rejected_tokens_gpu, "
      "SymInt num_reqs) -> ()",
      &cpu_utils::eagle_prepare_inputs_padded_kernel_impl);
  ops.def(
      "eagle_prepare_next_token_padded_kernel_impl("
      "Tensor sampled_token_ids, Tensor discard_request_mask, "
      "Tensor backup_next_token_ids, Tensor(a3!) next_token_ids, "
      "Tensor(a4!) valid_sampled_tokens_count, SymInt vocab_size, "
      "SymInt num_sampled_tokens_per_req, SymInt num_reqs) -> ()",
      &cpu_utils::eagle_prepare_next_token_padded_kernel_impl);
  ops.def(
      "eagle_step_slot_mapping_metadata_kernel_impl("
      "Tensor positions, Tensor block_table, Tensor(a2!) seq_lens, "
      "Tensor(a3!) out_clamped_positions, Tensor(a4!) out_slot_mapping, "
      "SymInt block_size, SymInt max_model_len, SymInt PAD_ID) -> ()",
      &cpu_utils::eagle_step_slot_mapping_metadata_kernel_impl);
  ops.def(
      "copy_and_expand_eagle_inputs_kernel_impl("
      "Tensor target_token_ids, Tensor target_positions, "
      "Tensor next_token_ids, Tensor(a3!) out_input_ids, "
      "Tensor(a4!) out_positions, "
      "Tensor(a5!) out_is_rejected_token_mask, "
      "Tensor(a6!) out_is_masked_token_mask, "
      "Tensor(a7!) out_new_token_indices, "
      "Tensor(a8!) out_hidden_state_mapping, "
      "Tensor query_start_loc, Tensor query_end_loc, "
      "SymInt padding_token_id, SymInt parallel_drafting_token_id, "
      "SymInt total_input_tokens, SymInt num_padding_slots_per_request, "
      "bool shift_input_ids) -> ()",
      &cpu_utils::copy_and_expand_eagle_inputs_kernel_impl);
  ops.def(
      "rejection_greedy_sample_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor target_argmax, "
      "Tensor bonus_token_ids, Tensor? is_greedy, "
      "SymInt max_spec_len) -> ()",
      &cpu_utils::rejection_greedy_sample_kernel_impl);
  ops.def(
      "rejection_random_sample_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor? draft_probs, "
      "Tensor target_probs, Tensor bonus_token_ids, "
      "Tensor recovered_token_ids, Tensor uniform_probs, "
      "Tensor? is_greedy, SymInt max_spec_len, SymInt vocab_size, "
      "bool no_draft_probs) -> ()",
      &cpu_utils::rejection_random_sample_kernel_impl);
  ops.def(
      "expand_kernel_impl(Tensor(a0!) output, Tensor input, "
      "Tensor cu_num_tokens, SymInt replace_from, "
      "SymInt replace_to) -> ()",
      &cpu_utils::expand_kernel_impl);
  ops.def(
      "sample_recovered_tokens_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor? draft_probs, "
      "Tensor target_probs, Tensor inv_q, SymInt vocab_size, "
      "bool no_draft_probs) -> ()",
      &cpu_utils::sample_recovered_tokens_kernel_impl);
}

487
// Extension-module registration for TORCH_EXTENSION_NAME.
// NOTE(review): REGISTER_EXTENSION is presumably provided by
// core/registration.h and is assumed to emit the Python module-init entry
// point for the library; confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)