torch_bindings.cpp 21.9 KB
Newer Older
1
2
#include "cache.h"
#include "ops.h"
3
#include "core/registration.h"
4
5
6

#include <torch/library.h>

Jiayi Yan's avatar
Jiayi Yan committed
7
// Note: override the externally supplied definition so that libraries built
// for different ISAs share the same extension name.
#define TORCH_EXTENSION_NAME _C

11
12
13
14
15
16
17
18
19
20
21
22
23
// oneDNN W8A8 (scaled matmul) entry points. They are registered in the op
// library below only when the oneDNN-capable targets are enabled.

// Releases the oneDNN matmul handler identified by `handler`.
void release_dnnl_matmul_handler(int64_t handler);

// Creates a oneDNN W8A8 handler from `b` and `b_scales` and returns it as an
// integer (registered with return type `int`).
// NOTE(review): presumably the returned value is what
// release_dnnl_matmul_handler expects — confirm against the implementation.
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
                                        const torch::Tensor& b_scales,
                                        at::ScalarType output_type,
                                        bool dynamic_act_quant, bool use_azp,
                                        int64_t primitive_cache_size);

// oneDNN scaled_mm for W8A8. `c` is the only argument the registered schema
// marks as mutated; `azp`, `azp_adj` and `bias` are optional.
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const torch::Tensor& a_scales,
                      const std::optional<torch::Tensor>& azp,
                      const std::optional<torch::Tensor>& azp_adj,
                      const std::optional<torch::Tensor>& bias,
                      const torch::Tensor& handler_tensor);
25

26
27
28
29
// Creates a oneDNN GEMM handler from `b` and returns it as an integer
// (registered with return type `int`).
int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                 int64_t primitive_cache_size);

// oneDNN GEMM. `c` is the only argument the registered schema marks as
// mutated; `bias` is optional.
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
               const std::optional<torch::Tensor>& bias,
               const torch::Tensor& handler_tensor);

// Reports whether oneDNN was built with the ACL backend.
bool is_onednn_acl_supported();

Thien Tran's avatar
Thien Tran committed
35
36
37
38
// MLA decode over a KV cache. The registered schema marks only `out` as
// mutated, although every tensor is taken by non-const reference here.
void mla_decode_kvcache(
    torch::Tensor& out,
    torch::Tensor& query,
    torch::Tensor& kv_cache,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens);

39
// Initializes a shared-memory manager and returns an integer (registered with
// return type `int`).
// NOTE(review): presumably the returned value is the `handle` accepted by the
// other shm_* functions — confirm against the implementation.
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
                         const int64_t rank, const int64_t thread_num);
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

// Shared-memory collective entry points. Each one takes the integer `handle`
// of a shared-memory manager as its first argument.

// Joins the manager `handle` under `name`; returns a string.
std::string join_shm_manager(
    int64_t handle,
    const std::string& name);

// All-reduce; `data` is marked as mutated in the registered schema.
void shm_allreduce(
    int64_t handle,
    torch::Tensor& data);

// Gather towards `dst`; `outputs` is an optional tensor list that the
// registered schema marks as mutated.
void shm_gather(
    int64_t handle,
    torch::Tensor& data,
    const std::optional<std::vector<torch::Tensor>>& outputs,
    int64_t dst);

// All-gather of `data`; `output` is marked as mutated in the registered
// schema.
void shm_all_gather(
    int64_t handle,
    const torch::Tensor& data,
    torch::Tensor& output);

// Sends `tensor_list` to `dst`.
void shm_send_tensor_list(
    int64_t handle,
    const std::vector<torch::Tensor>& tensor_list,
    int64_t dst);

// Receives a list of tensors from `src`.
std::vector<torch::Tensor> shm_recv_tensor_list(
    int64_t handle,
    int64_t src);

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// sgl-kernels entry points. They are registered in the op library below only
// when AVX512F, AVX512BF16 and AVX512VNNI are all available.

// Linear op on packed weights; `bias` is optional. Returns a new tensor.
at::Tensor weight_packed_linear(
    at::Tensor& mat1,
    at::Tensor& mat2,
    const std::optional<at::Tensor>& bias,
    bool is_vnni);

// Converts `weight` to the packed layout and returns the result.
// NOTE(review): the registered schema marks `weight` as mutated — confirm
// whether the input is modified in place.
at::Tensor convert_weight_packed(at::Tensor& weight);

// Fused experts op. All scale tensors and `block_size` are optional.
at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& topk_weights,
    at::Tensor& topk_ids,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// INT8 scaled matmul with quantization; the result has dtype `out_dtype`
// (presumably — confirm against the implementation). `bias` is optional.
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

80
81
82
83
84
85
86
87
// INT4 W4A8 kernels, adapted from sglang.

// Converts the quantized weight together with its zero points and scales;
// returns three tensors.
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
    at::Tensor qweight,
    at::Tensor qzeros,
    at::Tensor scales);

// INT4 scaled matmul of `x` with weight `w`, using `w_zeros` and `w_scales`;
// `bias` is optional. Returns a new tensor.
at::Tensor int4_scaled_mm_cpu(
    at::Tensor& x,
    at::Tensor& w,
    at::Tensor& w_zeros,
    at::Tensor& w_scales,
    std::optional<at::Tensor> bias);

88
89
90
// Activation selected by the `activation` string; `out` is marked as mutated
// in the registered schema. Registered only on non-Apple aarch64 builds.
// NOTE(review): name suggests a bf16 lookup-table implementation — confirm.
void activation_lut_bf16(
    torch::Tensor& out,
    torch::Tensor& input,
    const std::string& activation);

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// CPU attention kernel entry points.

// Builds scheduler metadata and returns it as a tensor.
// NOTE(review): `casual` is presumably a misspelling of `causal`; it is kept
// because the registered schema exposes the same keyword name to callers.
torch::Tensor get_scheduler_metadata(
    const int64_t num_req,
    const int64_t num_heads_q,
    const int64_t num_heads_kv,
    const int64_t head_dim,
    const torch::Tensor& seq_lens,
    at::ScalarType dtype,
    const torch::Tensor& query_start_loc,
    const bool casual,
    const int64_t window_size,
    const std::string& isa_hint,
    const bool enable_kv_split);

// Reshape-and-cache for `key` / `value`; `key_cache` and `value_cache` are
// the arguments the registered schema marks as mutated.
void cpu_attn_reshape_and_cache(
    const torch::Tensor& key,
    const torch::Tensor& value,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    const torch::Tensor& slot_mapping,
    const std::string& isa);

// Attention over the KV cache; `output` is the only argument the registered
// schema marks as mutated. `alibi_slopes` and `s_aux` are optional.
// NOTE(review): `scheduler_metadata` is presumably the tensor returned by
// get_scheduler_metadata — confirm against the caller.
void cpu_attention_with_kv_cache(
    const torch::Tensor& query,
    const torch::Tensor& key_cache,
    const torch::Tensor& value_cache,
    torch::Tensor& output,
    const torch::Tensor& query_start_loc,
    const torch::Tensor& seq_lens,
    const double scale,
    const bool causal,
    const std::optional<torch::Tensor>& alibi_slopes,
    const int64_t sliding_window_left,
    const int64_t sliding_window_right,
    const torch::Tensor& block_table,
    const double softcap,
    const torch::Tensor& scheduler_metadata,
    const std::optional<torch::Tensor>& s_aux);

117
118
119
// Stand-in implementation for ops that are registered only so that importing
// them does not fail. The check below is unconditionally false.
void placeholder_op() {
  TORCH_CHECK(false, "Unimplemented");
}

Li, Jiang's avatar
Li, Jiang committed
120
121
122
123
124
125
126
// WNA16 GEMM. `output` is the only argument the registered schema marks as
// mutated; `zeros`, `g_idx` and `bias` are optional. Registered only when
// AVX512F is available.
void cpu_gemm_wna16(
    const torch::Tensor& input,
    const torch::Tensor& q_weight,
    torch::Tensor& output,
    const torch::Tensor& scales,
    const std::optional<torch::Tensor>& zeros,
    const std::optional<torch::Tensor>& g_idx,
    const std::optional<torch::Tensor>& bias,
    const int64_t pack_factor,
    const std::string& isa_hint);

127
128
129
130
131
132
133
134
// Fused MoE entry points. They are registered in the op library below only
// when AVX512F is available.

// Prepacks `weight` into `packed_weight` (the argument the registered schema
// marks as mutated) for the given `isa`.
void prepack_moe_weight(const torch::Tensor& weight,
                        torch::Tensor& packed_weight, const std::string& isa);

// Fused MoE. `output` is the only argument the registered schema marks as
// mutated; `w13_bias` and `w2_bias` are optional.
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
                   const torch::Tensor& w13, const torch::Tensor& w2,
                   const std::optional<torch::Tensor>& w13_bias,
                   const std::optional<torch::Tensor>& w2_bias,
                   const torch::Tensor& topk_weights,
                   const torch::Tensor& topk_id, const bool skip_weighted,
                   const std::string& act, const std::string& isa);
137

138
139
140
141
142
143
// Computes the slot mapping. `slot_mapping` is the argument the registered
// schema marks as mutated; all tensors are passed by value.
void compute_slot_mapping_kernel_impl(
    const torch::Tensor query_start_loc,
    const torch::Tensor positions,
    const torch::Tensor block_table,
    torch::Tensor slot_mapping,
    const int64_t block_size);

144
145
// Registered below as "init_cpu_memory_env(SymInt[] node_ids) -> ()".
// NOTE(review): presumably sets up the CPU memory environment for the given
// (NUMA?) node ids — confirm against the implementation.
void init_cpu_memory_env(std::vector<int64_t> node_ids);

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
namespace cpu_utils {
void eagle_prepare_inputs_padded_kernel_impl(
    const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& valid_sampled_tokens_count,
    const torch::Tensor& query_start_loc_gpu,
    torch::Tensor& token_indices_to_sample,
    torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs);
void eagle_prepare_next_token_padded_kernel_impl(
    const torch::Tensor& sampled_token_ids,
    const torch::Tensor& discard_request_mask,
    const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
    torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
    const int64_t num_sampled_tokens_per_req, const int64_t num_reqs);
void eagle_step_slot_mapping_metadata_kernel_impl(
    const torch::Tensor& positions, const torch::Tensor& block_table,
    torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
    torch::Tensor& out_slot_mapping, const int64_t block_size,
    const int64_t max_model_len, const int64_t PAD_ID);
void copy_and_expand_eagle_inputs_kernel_impl(
    const torch::Tensor& target_token_ids,
    const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
    torch::Tensor& out_input_ids, torch::Tensor& out_positions,
    torch::Tensor& out_is_rejected_token_mask,
    torch::Tensor& out_is_masked_token_mask,
    torch::Tensor& out_new_token_indices,
    torch::Tensor& out_hidden_state_mapping,
    const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
    const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
    const int64_t total_input_tokens,
    const int64_t num_padding_slots_per_request, const bool shift_input_ids);
void rejection_greedy_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
    const torch::Tensor& bonus_token_ids,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len);
void rejection_random_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
    const torch::Tensor& recovered_token_ids,
    const torch::Tensor& uniform_probs,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
    const int64_t vocab_size, const bool no_draft_probs);
void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
                        const torch::Tensor& cu_num_tokens,
                        const int64_t replace_from, const int64_t replace_to);
void sample_recovered_tokens_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
    const int64_t vocab_size, const bool no_draft_probs);
}  // namespace cpu_utils

201
202
203
// Op registrations for the library named by TORCH_EXTENSION_NAME.
//
// Two registration forms are used:
//   * `ops.def(schema)` followed by `ops.impl(name, torch::kCPU, &fn)`, which
//     declares the schema and then attaches `fn` for the CPU dispatch key;
//   * `ops.def(schema, &fn)`, which declares the schema and attaches `fn` in
//     one call without naming a dispatch key.
// In a schema, `Tensor!` / `Tensor(aN!)` marks an argument the op mutates and
// `?` marks an optional argument. Several groups are compiled only for
// specific target ISAs; see the `#if` guards.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  ops.def(
      "dynamic_4bit_int_moe("
      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
      "int group_size, bool apply_router_weight_on_input, int activation_kind"
      ") -> Tensor");

  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);

  // Activation ops

  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

#if (defined(__aarch64__) && !defined(__APPLE__))

  ops.def(
      "activation_lut_bf16(Tensor! out, Tensor input, str activation)"
      " -> ()");
  ops.impl("activation_lut_bf16", torch::kCPU, &activation_lut_bf16);

#endif  // (defined(__aarch64__) && !defined(__APPLE__))

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor!? key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
    defined(__powerpc64__)
  // Helper function to release oneDNN handlers
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN GEMM handler
  ops.def(
      "create_onednn_mm_handler(Tensor b, int "
      "primitive_cache_size) -> int",
      &create_onednn_mm_handler);

  // oneDNN GEMM
  ops.def(
      "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
      "Tensor handler_tensor) -> ()");
  ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

  // Check if oneDNN was built with ACL backend
  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
      "output_type, bool dynamic_act_quant, bool use_azp, int "
      "primitive_cache_size) -> int",
      &create_onednn_scaled_mm_handler);

  // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
  ops.def(
      "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
      "Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()");
  ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
#endif  // oneDNN and int8 quantization ops

  // SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
  ops.def(
      "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
      "int",
      &init_shm_manager);
  ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
  ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
  ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  ops.def(
      "shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
      "()");
  ops.impl("shm_gather", torch::kCPU, &shm_gather);
  ops.def(
      "shm_all_gather(int handle, Tensor data, Tensor! output) -> "
      "()");
  ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
  ops.def(
      "shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
      "()");
  ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
  ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
          &shm_recv_tensor_list);
#endif  // SHM CCL: AVX512F, or aarch64 excluding Apple

  // sgl-kernels
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
  ops.def(
      "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
      "bias, bool is_vnni) -> Tensor");
  ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
  ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
  ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
  ops.def(
      "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
      "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
      "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
      "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
      "Tensor");
  ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
  ops.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
      "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
  ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
           &int8_scaled_mm_with_quant);

  // Adapted from sglang: INT4 W4A8 kernels
  ops.def(
      "convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
      "Tensor scales) -> (Tensor, Tensor, Tensor)");
  ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
           &convert_weight_packed_scale_zp);

  ops.def(
      "int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
      "Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
  ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif  // sgl-kernels

  // CPU attention kernels
  ops.def(
      "get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
      "int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
      "query_start_loc, bool casual, int window_size, str isa_hint, bool "
      "enable_kv_split) -> Tensor",
      &get_scheduler_metadata);
  ops.def(
      "cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
      "key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
      "isa) -> ()",
      &cpu_attn_reshape_and_cache);
  ops.def(
      "cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
      "value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
      "seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
      "sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
      "float softcap, Tensor scheduler_metadata, Tensor? s_aux) -> ()",
      &cpu_attention_with_kv_cache);

  // placeholders
  ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);

  // WNA16
#if defined(__AVX512F__)
  ops.def(
      "cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
      "Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
      "pack_factor, str isa_hint) -> ()");
  ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif  // WNA16

  // fused moe
#if defined(__AVX512F__)
  ops.def(
      "prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
      "-> ()");
  ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
  ops.def(
      "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
      "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
      "bool skip_weighted, "
      "str act, str isa) -> ()");
  ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif  // fused moe

  ops.def(
      "mla_decode_kvcache("
      "   Tensor! out, Tensor query, Tensor kv_cache,"
      "   float scale, Tensor block_tables, Tensor seq_lens) -> ()");
  ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);

  ops.def(
      "compute_slot_mapping_kernel_impl(Tensor query_start_loc, Tensor "
      "positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
      "block_size) -> ()",
      &compute_slot_mapping_kernel_impl);

  ops.def("init_cpu_memory_env(SymInt[] node_ids) -> ()", &init_cpu_memory_env);

  // Speculative decoding kernels
  ops.def(
      "eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
      "Tensor valid_sampled_tokens_count, Tensor query_start_loc_gpu, "
      "Tensor(a3!) token_indices_to_sample, "
      "Tensor(a4!) num_rejected_tokens_gpu, "
      "SymInt num_reqs) -> ()",
      &cpu_utils::eagle_prepare_inputs_padded_kernel_impl);
  ops.def(
      "eagle_prepare_next_token_padded_kernel_impl("
      "Tensor sampled_token_ids, Tensor discard_request_mask, "
      "Tensor backup_next_token_ids, Tensor(a3!) next_token_ids, "
      "Tensor(a4!) valid_sampled_tokens_count, SymInt vocab_size, "
      "SymInt num_sampled_tokens_per_req, SymInt num_reqs) -> ()",
      &cpu_utils::eagle_prepare_next_token_padded_kernel_impl);
  ops.def(
      "eagle_step_slot_mapping_metadata_kernel_impl("
      "Tensor positions, Tensor block_table, Tensor(a2!) seq_lens, "
      "Tensor(a3!) out_clamped_positions, Tensor(a4!) out_slot_mapping, "
      "SymInt block_size, SymInt max_model_len, SymInt PAD_ID) -> ()",
      &cpu_utils::eagle_step_slot_mapping_metadata_kernel_impl);
  ops.def(
      "copy_and_expand_eagle_inputs_kernel_impl("
      "Tensor target_token_ids, Tensor target_positions, "
      "Tensor next_token_ids, Tensor(a3!) out_input_ids, "
      "Tensor(a4!) out_positions, "
      "Tensor(a5!) out_is_rejected_token_mask, "
      "Tensor(a6!) out_is_masked_token_mask, "
      "Tensor(a7!) out_new_token_indices, "
      "Tensor(a8!) out_hidden_state_mapping, "
      "Tensor query_start_loc, Tensor query_end_loc, "
      "SymInt padding_token_id, SymInt parallel_drafting_token_id, "
      "SymInt total_input_tokens, SymInt num_padding_slots_per_request, "
      "bool shift_input_ids) -> ()",
      &cpu_utils::copy_and_expand_eagle_inputs_kernel_impl);
  ops.def(
      "rejection_greedy_sample_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor target_argmax, "
      "Tensor bonus_token_ids, Tensor? is_greedy, "
      "SymInt max_spec_len) -> ()",
      &cpu_utils::rejection_greedy_sample_kernel_impl);
  ops.def(
      "rejection_random_sample_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor? draft_probs, "
      "Tensor target_probs, Tensor bonus_token_ids, "
      "Tensor recovered_token_ids, Tensor uniform_probs, "
      "Tensor? is_greedy, SymInt max_spec_len, SymInt vocab_size, "
      "bool no_draft_probs) -> ()",
      &cpu_utils::rejection_random_sample_kernel_impl);
  ops.def(
      "expand_kernel_impl(Tensor(a0!) output, Tensor input, "
      "Tensor cu_num_tokens, SymInt replace_from, "
      "SymInt replace_to) -> ()",
      &cpu_utils::expand_kernel_impl);
  ops.def(
      "sample_recovered_tokens_kernel_impl("
      "Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
      "Tensor draft_token_ids, Tensor? draft_probs, "
      "Tensor target_probs, Tensor inv_q, SymInt vocab_size, "
      "bool no_draft_probs) -> ()",
      &cpu_utils::sample_recovered_tokens_kernel_impl);
}

503
// Invokes REGISTER_EXTENSION for the library name defined above.
// NOTE(review): the macro is not defined in this file (presumably it comes
// from core/registration.h and emits the Python module init entry point so
// the shared library can be imported) — confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)