rocm_ops.hpp 124 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
// SPDX-License-Identifier: MIT

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// ---- aiter_tensor_t / stream pybind infrastructure (ported from aiter-github) ----
#include "aiter_tensor.h"
#include "aiter_stream.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

// Binds `_set_current_hip_stream(stream_ptr)` on the module `m`. The Python side passes the
// stream handle as an integer; it is reinterpreted as a hipStream_t and forwarded to
// aiter::setCurrentHIPStream. The macro is guarded so an earlier definition wins.
#ifndef AITER_SET_STREAM_PYBIND
#define AITER_SET_STREAM_PYBIND                                                \
    m.def(                                                                     \
        "_set_current_hip_stream",                                             \
        [](int64_t raw_handle) {                                               \
            const auto stream = reinterpret_cast<hipStream_t>(raw_handle);     \
            aiter::setCurrentHIPStream(stream);                                \
        },                                                                     \
        pybind11::arg("stream_ptr"));
#endif

// Registers the core aiter bindings on the module `m`:
//   * the QuantType and ActivationType enums (values exported, implicit conversion from int),
//   * `_set_current_hip_stream` (through AITER_SET_STREAM_PYBIND),
//   * the `aiter_tensor_t` class: default constructor, a raw-field constructor
//     (data_ptr, numel, ndim, shape, strides, dtype, device_id) and read/write access to
//     numel_, ndim and device_id.
//
// The raw-field constructor copies at most the first 8 shape/stride entries. It raises
// pybind11::value_error (ValueError in Python) when `shape` or `strides` holds fewer entries
// than it would copy; previously that case indexed past the end of the vectors.
// NOTE(review): ndim itself is stored unclamped, exactly as before - confirm that callers
// never pass ndim > 8.
// The aiter_tensor_t registration mirrors AITER_TENSOR_PYBIND; keep the two in sync.
#ifndef AITER_CORE_PYBIND
#define AITER_CORE_PYBIND                                                                      \
    pybind11::enum_<QuantType>(m, "QuantType")                                                 \
        .value("No", QuantType::No)                                                            \
        .value("per_Tensor", QuantType::per_Tensor)                                            \
        .value("per_Token", QuantType::per_Token)                                              \
        .value("per_1x32", QuantType::per_1x32)                                                \
        .value("per_1x128", QuantType::per_1x128)                                              \
        .value("per_128x128", QuantType::per_128x128)                                          \
        .value("per_256x128", QuantType::per_256x128)                                          \
        .value("per_1024x128", QuantType::per_1024x128)                                        \
        .export_values();                                                                      \
    pybind11::enum_<ActivationType>(m, "ActivationType")                                       \
        .value("No", ActivationType::No)                                                       \
        .value("Silu", ActivationType::Silu)                                                   \
        .value("Gelu", ActivationType::Gelu)                                                   \
        .value("Swiglu", ActivationType::Swiglu)                                               \
        .export_values();                                                                      \
    pybind11::implicitly_convertible<int, QuantType>();                                        \
    pybind11::implicitly_convertible<int, ActivationType>();                                   \
    AITER_SET_STREAM_PYBIND                                                                    \
    pybind11::class_<aiter_tensor_t>(m, "aiter_tensor_t")                                      \
        .def(pybind11::init<>())                                                               \
        .def(pybind11::init([](int64_t data_ptr, size_t numel, int ndim,                       \
                               const std::vector<int64_t>& shape,                              \
                               const std::vector<int64_t>& strides,                            \
                               int dtype, int device_id) {                                     \
                 /* Entries actually copied: ndim clamped to the range 0..8. */                \
                 const size_t n_copy =                                                         \
                     ndim <= 0 ? size_t{0}                                                     \
                               : (ndim < 8 ? static_cast<size_t>(ndim) : size_t{8});           \
                 /* Reject short lists instead of reading past the end of the vectors. */      \
                 if(shape.size() < n_copy || strides.size() < n_copy) {                        \
                     throw pybind11::value_error(                                              \
                         "aiter_tensor_t: shape and strides must each provide at least "      \
                         "min(ndim, 8) entries");                                              \
                 }                                                                             \
                 aiter_tensor_t at{};                                                          \
                 at.ptr = reinterpret_cast<void*>(data_ptr);                                   \
                 at.numel_ = numel;                                                            \
                 at.ndim = ndim;                                                               \
                 for(size_t i = 0; i < n_copy; i++) {                                          \
                     at.shape[i] = shape[i];                                                   \
                     at.strides[i] = strides[i];                                               \
                 }                                                                             \
                 at.dtype_ = static_cast<AiterDtype>(dtype);                                   \
                 at.device_id = device_id;                                                     \
                 return at;                                                                    \
             }),                                                                               \
             pybind11::arg("data_ptr"),                                                        \
             pybind11::arg("numel"),                                                           \
             pybind11::arg("ndim"),                                                            \
             pybind11::arg("shape"),                                                           \
             pybind11::arg("strides"),                                                         \
             pybind11::arg("dtype"),                                                           \
             pybind11::arg("device_id"))                                                       \
        .def_readwrite("numel_", &aiter_tensor_t::numel_)                                      \
        .def_readwrite("ndim", &aiter_tensor_t::ndim)                                          \
        .def_readwrite("device_id", &aiter_tensor_t::device_id);
#endif

// Registers only aiter_tensor_t and _set_current_hip_stream — no enum registrations.
// Use this in modules that already have QuantType/ActivationType (e.g. via module_aiter_enum).
// Expands to `_set_current_hip_stream` (through AITER_SET_STREAM_PYBIND) followed by the
// `aiter_tensor_t` class registration on the module `m`: default constructor, a raw-field
// constructor (data_ptr, numel, ndim, shape, strides, dtype, device_id) and read/write access
// to numel_, ndim and device_id.
//
// The raw-field constructor copies at most the first 8 shape/stride entries. It raises
// pybind11::value_error (ValueError in Python) when `shape` or `strides` holds fewer entries
// than it would copy; previously that case indexed past the end of the vectors.
// NOTE(review): ndim itself is stored unclamped, exactly as before - confirm that callers
// never pass ndim > 8.
#ifndef AITER_TENSOR_PYBIND
#define AITER_TENSOR_PYBIND                                                                    \
    AITER_SET_STREAM_PYBIND                                                                    \
    pybind11::class_<aiter_tensor_t>(m, "aiter_tensor_t")                                      \
        .def(pybind11::init<>())                                                               \
        .def(pybind11::init([](int64_t data_ptr, size_t numel, int ndim,                       \
                               const std::vector<int64_t>& shape,                              \
                               const std::vector<int64_t>& strides,                            \
                               int dtype, int device_id) {                                     \
                 /* Entries actually copied: ndim clamped to the range 0..8. */                \
                 const size_t n_copy =                                                         \
                     ndim <= 0 ? size_t{0}                                                     \
                               : (ndim < 8 ? static_cast<size_t>(ndim) : size_t{8});           \
                 /* Reject short lists instead of reading past the end of the vectors. */      \
                 if(shape.size() < n_copy || strides.size() < n_copy) {                        \
                     throw pybind11::value_error(                                              \
                         "aiter_tensor_t: shape and strides must each provide at least "      \
                         "min(ndim, 8) entries");                                              \
                 }                                                                             \
                 aiter_tensor_t at{};                                                          \
                 at.ptr = reinterpret_cast<void*>(data_ptr);                                   \
                 at.numel_ = numel;                                                            \
                 at.ndim = ndim;                                                               \
                 for(size_t i = 0; i < n_copy; i++) {                                          \
                     at.shape[i] = shape[i];                                                   \
                     at.strides[i] = strides[i];                                               \
                 }                                                                             \
                 at.dtype_ = static_cast<AiterDtype>(dtype);                                   \
                 at.device_id = device_id;                                                     \
                 return at;                                                                    \
             }),                                                                               \
             pybind11::arg("data_ptr"),                                                        \
             pybind11::arg("numel"),                                                           \
             pybind11::arg("ndim"),                                                            \
             pybind11::arg("shape"),                                                           \
             pybind11::arg("strides"),                                                         \
             pybind11::arg("dtype"),                                                           \
             pybind11::arg("device_id"))                                                       \
        .def_readwrite("numel_", &aiter_tensor_t::numel_)                                      \
        .def_readwrite("ndim", &aiter_tensor_t::ndim)                                          \
        .def_readwrite("device_id", &aiter_tensor_t::device_id);
#endif
// ---- end aiter_tensor_t / stream pybind infrastructure ----

Xiaowei.zhang's avatar
Xiaowei.zhang committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
// Binds the `*_and_mul` activation entry points from namespace aiter on the module `m`.
// Each binding names its arguments (out, input); scaled_silu_and_mul also takes `scale`.
#define ACTIVATION_PYBIND                                         \
    m.def("silu_and_mul",                                         \
          &aiter::silu_and_mul,                                   \
          "Activation function used in SwiGLU.",                  \
          py::arg("out"),                                         \
          py::arg("input"));                                      \
    m.def("scaled_silu_and_mul",                                  \
          &aiter::scaled_silu_and_mul,                            \
          "Activation function used in scaled SwiGLU.",           \
          py::arg("out"),                                         \
          py::arg("input"),                                       \
          py::arg("scale"));                                      \
    m.def("gelu_and_mul",                                         \
          &aiter::gelu_and_mul,                                   \
          "Activation function used in GELU.",                    \
          py::arg("out"),                                         \
          py::arg("input"));                                      \
    m.def("gelu_tanh_and_mul",                                    \
          &aiter::gelu_tanh_and_mul,                              \
          "Activation function used in GELU tanh.",               \
          py::arg("out"),                                         \
          py::arg("input"));

// Binds the binary element-wise operators on the module `m`: aiter_add/mul/sub/div are exposed
// as add/mul/sub/div, and the trailing-underscore functions as add_/mul_/sub_/div_.
// No argument names are declared, so Python callers pass arguments positionally.
#define AITER_OPERATOR_PYBIND                                     \
    m.def("add",                                                  \
          &aiter_add,                                             \
          "apply for add with transpose and broadcast.");         \
    m.def("mul",                                                  \
          &aiter_mul,                                             \
          "apply for mul with transpose and broadcast.");         \
    m.def("sub",                                                  \
          &aiter_sub,                                             \
          "apply for sub with transpose and broadcast.");         \
    m.def("div",                                                  \
          &aiter_div,                                             \
          "apply for div with transpose and broadcast.");         \
    m.def("add_",                                                 \
          &aiter_add_,                                            \
          "apply for add_ with transpose and broadcast.");        \
    m.def("mul_",                                                 \
          &aiter_mul_,                                            \
          "apply for mul_ with transpose and broadcast.");        \
    m.def("sub_",                                                 \
          &aiter_sub_,                                            \
          "apply for sub_ with transpose and broadcast.");        \
    m.def("div_",                                                 \
          &aiter_div_,                                            \
          "apply for div_ with transpose and broadcast.");
// Binds the unary operators aiter_sigmoid and aiter_tanh on the module `m` under the Python
// names sigmoid and tanh.
#define AITER_UNARY_PYBIND             \
    m.def("sigmoid",                   \
          &aiter_sigmoid,              \
          "apply for sigmoid.");       \
    m.def("tanh",                      \
          &aiter_tanh,                 \
          "apply for tanh.");

// Binds mla_decode_stage1_asm_fwd and mla_prefill_asm_fwd on the module `m`.
// Both bindings declare the same ten keyword names, in the same order, with no defaults.
#define ATTENTION_ASM_MLA_PYBIND                                              \
    m.def("mla_decode_stage1_asm_fwd",                                        \
          &mla_decode_stage1_asm_fwd,                                         \
          "mla_decode_stage1_asm_fwd",                                        \
          py::arg("Q"), py::arg("KV"),                                        \
          py::arg("qo_indptr"), py::arg("kv_indptr"),                         \
          py::arg("kv_page_indices"), py::arg("kv_last_page_lens"),           \
          py::arg("max_seqlen_q"), py::arg("softmax_scale"),                  \
          py::arg("splitData"), py::arg("splitLse"));                         \
    m.def("mla_prefill_asm_fwd",                                              \
          &mla_prefill_asm_fwd,                                               \
          "mla_prefill_asm_fwd",                                              \
          py::arg("Q"), py::arg("KV"),                                        \
          py::arg("qo_indptr"), py::arg("kv_indptr"),                         \
          py::arg("kv_page_indices"), py::arg("kv_last_page_lens"),           \
          py::arg("max_seqlen_q"), py::arg("softmax_scale"),                  \
          py::arg("splitData"), py::arg("splitLse"));

// Binds the C++ function pa_fwd on the module `m` under the Python name `pa_fwd_asm`.
// K_QScale, V_QScale and out_ default to std::nullopt; high_precision defaults to 1.
#define ATTENTION_ASM_PYBIND                                                  \
    m.def("pa_fwd_asm",                                                       \
          &pa_fwd,                                                            \
          "pa_fwd",                                                           \
          py::arg("Q"), py::arg("K"), py::arg("V"),                           \
          py::arg("block_tables"), py::arg("context_lens"),                   \
          py::arg("max_num_blocks"),                                          \
          py::arg("K_QScale")       = std::nullopt,                           \
          py::arg("V_QScale")       = std::nullopt,                           \
          py::arg("out_")           = std::nullopt,                           \
          py::arg("high_precision") = 1);

// Binds pa_fwd_naive on the module `m`. All arguments are named; only the trailing `out_`
// has a default (std::nullopt).
#define ATTENTION_CK_PYBIND                                                   \
    m.def("pa_fwd_naive",                                                     \
          &pa_fwd_naive,                                                      \
          "pa_fwd_naive",                                                     \
          py::arg("Q"), py::arg("K"), py::arg("V"),                           \
          py::arg("block_tables"), py::arg("context_lens"),                   \
          py::arg("k_dequant_scales"), py::arg("v_dequant_scales"),           \
          py::arg("max_seq_len"), py::arg("num_kv_heads"),                    \
          py::arg("scale_s"), py::arg("scale_k"), py::arg("scale_v"),         \
          py::arg("block_size"), py::arg("quant_algo"),                       \
          py::arg("out_") = std::nullopt);

// Binds the C++ function paged_attention on the module `m` under the Python name
// `paged_attention_rocm`. The third m.def argument is the docstring: a signature-like
// description assembled from adjacent string literals. No py::arg names are declared.
#define ATTENTION_PYBIND                                                  \
    m.def(                                                                \
        "paged_attention_rocm",                                           \
        &paged_attention,                                                 \
        "paged_attention_rocm(Tensor! out, Tensor exp_sums,"              \
        "                Tensor max_logits, Tensor tmp_out,"              \
        "                Tensor query, Tensor key_cache,"                 \
        "                Tensor value_cache, int num_kv_heads,"           \
        "                float scale, Tensor block_tables,"               \
        "                Tensor context_lens, int block_size,"            \
        "                int max_context_len,"                            \
        "                Tensor? alibi_slopes,"                           \
        "                str kv_cache_dtype,"                             \
        "                float k_scale, float v_scale) -> ()");

// Binds paged_attention_ragged on the module `m`. The third m.def argument is the docstring:
// a signature-like description assembled from adjacent string literals. No py::arg names are
// declared.
#define ATTENTION_RAGGED_PYBIND                                           \
    m.def(                                                                \
        "paged_attention_ragged",                                         \
        &paged_attention_ragged,                                          \
        "paged_attention_ragged(Tensor! out, Tensor exp_sums,"            \
        "                Tensor max_logits, Tensor tmp_out,"              \
        "                Tensor query, Tensor key_cache,"                 \
        "                Tensor value_cache, int num_kv_heads,"           \
        "                float scale, Tensor block_tables,"               \
        "                Tensor context_lens, int block_size,"            \
        "                int max_context_len,"                            \
        "                Tensor? alibi_slopes,"                           \
        "                str kv_cache_dtype,"                             \
        "                float k_scale, float v_scale) -> ()");

// Binds batched_gemm_a8w8 on the module `m`. XQ, WQ, x_scale, w_scale and Out are required;
// bias defaults to std::nullopt and splitK defaults to 0.
#define BATCHED_GEMM_A8W8_PYBIND                   \
    m.def("batched_gemm_a8w8",                     \
          &batched_gemm_a8w8,                      \
          "batched_gemm_a8w8",                     \
          py::arg("XQ"),                           \
          py::arg("WQ"),                           \
          py::arg("x_scale"),                      \
          py::arg("w_scale"),                      \
          py::arg("Out"),                          \
          py::arg("bias")   = std::nullopt,        \
          py::arg("splitK") = 0);

// Binds batched_gemm_a8w8_tune on the module `m`. XQ, WQ, x_scale, w_scale and Out are
// required; kernelId and splitK both default to 0.
#define BATCHED_GEMM_A8W8_TUNE_PYBIND              \
    m.def("batched_gemm_a8w8_tune",                \
          &batched_gemm_a8w8_tune,                 \
          "batched_gemm_a8w8_tune",                \
          py::arg("XQ"),                           \
          py::arg("WQ"),                           \
          py::arg("x_scale"),                      \
          py::arg("w_scale"),                      \
          py::arg("Out"),                          \
          py::arg("kernelId") = 0,                 \
          py::arg("splitK")   = 0);

// Binds the cache helpers on the module `m`: swap_blocks, copy_blocks, reshape_and_cache,
// reshape_and_cache_flash, reshape_and_cache_with_pertoken_quant,
// reshape_and_cache_with_block_quant and convert_fp8.
// The string after each function pointer is the docstring (a signature-like description
// assembled from adjacent literals); no py::arg names are declared.
#define CACHE_PYBIND                                                                  \
    m.def(                                                                            \
        "swap_blocks",                                                                \
        &swap_blocks,                                                                 \
        "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");          \
    m.def(                                                                            \
        "copy_blocks",                                                                \
        &copy_blocks,                                                                 \
        "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "            \
        "Tensor block_mapping) -> ()");                                               \
    m.def(                                                                            \
        "reshape_and_cache",                                                          \
        &reshape_and_cache,                                                           \
        "reshape_and_cache");                                                         \
    m.def(                                                                            \
        "reshape_and_cache_flash",                                                    \
        &reshape_and_cache_flash,                                                     \
        "reshape_and_cache_flash(Tensor key, Tensor value,"                           \
        "                        Tensor! key_cache,"                                  \
        "                        Tensor! value_cache,"                                \
        "                        Tensor slot_mapping,"                                \
        "                        str kv_cache_dtype,"                                 \
        "                        float k_scale, float v_scale) -> ()");               \
    m.def(                                                                            \
        "reshape_and_cache_with_pertoken_quant",                                      \
        &reshape_and_cache_with_pertoken_quant,                                       \
        "reshape_and_cache_with_pertoken_quant(Tensor key, Tensor value,"             \
        "                        Tensor! key_cache,"                                  \
        "                        Tensor! value_cache,"                                \
        "                        Tensor! k_dequant_scales,"                           \
        "                        Tensor! v_dequant_scales,"                           \
        "                        Tensor slot_mapping) -> ()");                        \
    m.def(                                                                            \
        "reshape_and_cache_with_block_quant",                                         \
        &reshape_and_cache_with_block_quant,                                          \
        "reshape_and_cache_with_block_quant(Tensor key, Tensor value,"                \
        "                        Tensor! key_cache,"                                  \
        "                        Tensor! value_cache,"                                \
        "                        Tensor! k_dequant_scales,"                           \
        "                        Tensor! v_dequant_scales,"                           \
        "                        Tensor slot_mapping,"                                \
        "                        const bool asm_layout) -> ()");                      \
    m.def(                                                                            \
        "convert_fp8",                                                                \
        &convert_fp8,                                                                 \
        "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "              \
        "str kv_cache_dtype) -> ()");

// Registers the custom all-reduce bindings on the module `m`.
// The macro first expands AITER_TENSOR_PYBIND (the `aiter_tensor_t` class plus
// `_set_current_hip_stream`) and then binds the aiter:: collective and buffer-management
// entry points listed below. Every argument is named; the only default is
// `bf16_out_ptr = 0` on fused_allreduce_rmsnorm_quant_per_group.
// NOTE(review): because AITER_TENSOR_PYBIND is expanded here, a module that also expands
// AITER_CORE_PYBIND or AITER_TENSOR_PYBIND would register aiter_tensor_t twice - confirm
// this macro is used on its own.
#define CUSTOM_ALL_REDUCE_PYBIND                                                               \
    AITER_TENSOR_PYBIND                                                                        \
    m.def("init_custom_ar",                                                                    \
          &aiter::init_custom_ar,                                                              \
          py::arg("meta_ptr"),                                                                 \
          py::arg("rank_data_ptr"),                                                            \
          py::arg("rank_data_sz"),                                                             \
          py::arg("ipc_handle_ptrs"),                                                          \
          py::arg("offsets"),                                                                  \
          py::arg("rank"),                                                                     \
          py::arg("fully_connected"));                                                         \
    m.def("all_reduce",                                                                        \
          &aiter::all_reduce,                                                                  \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("out"),                                                                      \
          py::arg("use_new"),                                                                  \
          py::arg("open_fp8_quant"),                                                           \
          py::arg("reg_inp_ptr"),                                                              \
          py::arg("reg_inp_bytes"));                                                           \
    m.def("reduce_scatter",                                                                    \
          &aiter::reduce_scatter,                                                              \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("out"),                                                                      \
          py::arg("reg_ptr"),                                                                  \
          py::arg("reg_bytes"));                                                               \
    m.def("all_gather_reg",                                                                    \
          &aiter::all_gather_reg,                                                              \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("out"),                                                                      \
          py::arg("dim"));                                                                     \
    m.def("all_gather_unreg",                                                                  \
          &aiter::all_gather_unreg,                                                            \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("reg_buffer"),                                                               \
          py::arg("out"),                                                                      \
          py::arg("reg_bytes"),                                                                \
          py::arg("dim"));                                                                     \
    m.def("fused_allreduce_rmsnorm",                                                           \
          &aiter::fused_allreduce_rmsnorm,                                                     \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("res_inp"),                                                                  \
          py::arg("res_out"),                                                                  \
          py::arg("out"),                                                                      \
          py::arg("w"),                                                                        \
          py::arg("eps"),                                                                      \
          py::arg("reg_ptr"),                                                                  \
          py::arg("reg_bytes"),                                                                \
          py::arg("use_1stage"));                                                              \
    m.def("fused_allreduce_rmsnorm_quant",                                                     \
          &aiter::fused_allreduce_rmsnorm_quant,                                               \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("res_inp"),                                                                  \
          py::arg("res_out"),                                                                  \
          py::arg("out"),                                                                      \
          py::arg("scale_out"),                                                                \
          py::arg("w"),                                                                        \
          py::arg("eps"),                                                                      \
          py::arg("reg_ptr"),                                                                  \
          py::arg("reg_bytes"),                                                                \
          py::arg("use_1stage"));                                                              \
    m.def("fused_allreduce_rmsnorm_quant_per_group",                                           \
          &aiter::fused_allreduce_rmsnorm_quant_per_group,                                     \
          py::arg("_fa"),                                                                      \
          py::arg("inp"),                                                                      \
          py::arg("res_inp"),                                                                  \
          py::arg("res_out"),                                                                  \
          py::arg("out"),                                                                      \
          py::arg("scale_out"),                                                                \
          py::arg("w"),                                                                        \
          py::arg("eps"),                                                                      \
          py::arg("group_size"),                                                               \
          py::arg("reg_ptr"),                                                                  \
          py::arg("reg_bytes"),                                                                \
          py::arg("use_1stage"),                                                               \
          py::arg("bf16_out_ptr") = static_cast<int64_t>(0));                                  \
    m.def("fused_qknorm_allreduce",                                                            \
          &aiter::fused_qknorm_allreduce,                                                      \
          py::arg("_fa"),                                                                      \
          py::arg("qkv_in"),                                                                   \
          py::arg("q_w"),                                                                      \
          py::arg("k_w"),                                                                      \
          py::arg("q_out"),                                                                    \
          py::arg("k_out"),                                                                    \
          py::arg("v_out"),                                                                    \
          py::arg("eps"),                                                                      \
          py::arg("reg_ptr"),                                                                  \
          py::arg("reg_bytes"));                                                               \
    m.def("dispose", &aiter::dispose, py::arg("_fa"));                                         \
    m.def("meta_size", &aiter::meta_size);                                                     \
    m.def("register_input_buffer",                                                             \
          &aiter::register_input_buffer,                                                       \
          py::arg("_fa"),                                                                      \
          py::arg("self_ptr"),                                                                 \
          py::arg("ipc_handle_ptrs"),                                                          \
          py::arg("offsets"));                                                                 \
    m.def("register_output_buffer",                                                            \
          &aiter::register_output_buffer,                                                      \
          py::arg("_fa"),                                                                      \
          py::arg("self_ptr"),                                                                 \
          py::arg("ipc_handle_ptrs"),                                                          \
          py::arg("offsets"));                                                                 \
    m.def("get_graph_buffer_count", &aiter::get_graph_buffer_count, py::arg("_fa"));           \
    m.def("get_graph_buffer_ipc_meta",                                                         \
          &aiter::get_graph_buffer_ipc_meta,                                                   \
          py::arg("_fa"),                                                                      \
          py::arg("handle_out"),                                                               \
          py::arg("offset_out"));                                                              \
    m.def("register_graph_buffers",                                                            \
          &aiter::register_graph_buffers,                                                      \
          py::arg("_fa"),                                                                      \
          py::arg("handle_ptrs"),                                                              \
          py::arg("offset_ptrs"));                                                             \
    m.def("allocate_meta_buffer", &aiter::allocate_meta_buffer, py::arg("size"));              \
    m.def("free_meta_buffer", &aiter::free_meta_buffer, py::arg("ptr"));                       \
    m.def("get_meta_buffer_ipc_handle",                                                        \
          &aiter::get_meta_buffer_ipc_handle,                                                  \
          py::arg("inp_ptr"),                                                                  \
          py::arg("out_handle_ptr"));
Xiaowei.zhang's avatar
Xiaowei.zhang committed
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998



// Registers `wvSpltK` and `LLMM1` on the object named `m` at the expansion site.
// The string literal following each function pointer is forwarded to `m.def`
// as an extra argument (adjacent literals are concatenated by the compiler).
#define CUSTOM_PYBIND \
    m.def("wvSpltK", \
          &wvSpltK, \
          "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," \
          "        int CuCount) -> ()"); \
    m.def("LLMM1", \
          &LLMM1, \
          "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " \
          "()");

// Registers `gemm_a8w8_asm` on `m`, one keyword argument per line.
// The first six arguments are required; the tiling/padding/split-K arguments
// carry the integer defaults shown.
#define GEMM_A8W8_ASM_PYBIND \
    m.def("gemm_a8w8_asm", \
          &gemm_a8w8_asm, \
          "Asm gemm a8w8 ,  weight should be shuffle to layout(32,16)", \
          py::arg("XQ"), \
          py::arg("WQ"), \
          py::arg("x_scale"), \
          py::arg("w_scale"), \
          py::arg("Out"), \
          py::arg("bias"), \
          py::arg("sub_m") = 128, \
          py::arg("sub_n") = 128, \
          py::arg("pad_a") = 0, \
          py::arg("pad_b") = 0, \
          py::arg("pad_c") = 0, \
          py::arg("splitK") = 0);

// Registers `gemm_a8w8_blockscale` on `m`; all five keyword arguments are required.
#define GEMM_A8W8_BLOCKSCALE_PYBIND \
    m.def("gemm_a8w8_blockscale", \
          &gemm_a8w8_blockscale, \
          "fp8 blockscale gemm", \
          py::arg("XQ"), \
          py::arg("WQ"), \
          py::arg("x_scale"), \
          py::arg("w_scale"), \
          py::arg("Out"));

// Registers `gemm_a8w8_blockscale_tune` on `m`.
// `kernelId` and `splitK` default to 0; the remaining arguments are required.
#define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND \
    m.def("gemm_a8w8_blockscale_tune", \
          &gemm_a8w8_blockscale_tune, \
          "gemm_a8w8_blockscale_tune", \
          py::arg("XQ"), \
          py::arg("WQ"), \
          py::arg("x_scale"), \
          py::arg("w_scale"), \
          py::arg("Out"), \
          py::arg("kernelId") = 0, \
          py::arg("splitK") = 0);

// Registers `gemm_a8w8` on `m`.
// `bias` defaults to `std::nullopt` and `splitK` to 0; the rest are required.
#define GEMM_A8W8_PYBIND \
    m.def("gemm_a8w8", \
          &gemm_a8w8, \
          "gemm_a8w8", \
          py::arg("XQ"), \
          py::arg("WQ"), \
          py::arg("x_scale"), \
          py::arg("w_scale"), \
          py::arg("Out"), \
          py::arg("bias") = std::nullopt, \
          py::arg("splitK") = 0);

// Registers `gemm_a8w8_tune` on `m`.
// `kernelId` and `splitK` default to 0; the remaining arguments are required.
#define GEMM_A8W8_TUNE_PYBIND \
    m.def("gemm_a8w8_tune", \
          &gemm_a8w8_tune, \
          "gemm_a8w8_tune", \
          py::arg("XQ"), \
          py::arg("WQ"), \
          py::arg("x_scale"), \
          py::arg("w_scale"), \
          py::arg("Out"), \
          py::arg("kernelId") = 0, \
          py::arg("splitK") = 0);

// Registers `aiter::torch_itfs::fmha_v3_bwd` on `m` under the name "fmha_v3_bwd".
// Arguments up to `how_v3_bf16_cvt` are required; the trailing ones
// (`dq`, `dk`, `dv`, `alibi_slopes`, `rng_state`, `gen`) default to `std::nullopt`.
#define MHA_BWD_ASM_PYBIND \
    m.def("fmha_v3_bwd", \
          &aiter::torch_itfs::fmha_v3_bwd, \
          py::arg("dout"), \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("out"), \
          py::arg("softmax_lse"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("deterministic"), \
          py::arg("is_v3_atomic_fp32"), \
          py::arg("how_v3_bf16_cvt"), \
          py::arg("dq") = std::nullopt, \
          py::arg("dk") = std::nullopt, \
          py::arg("dv") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("rng_state") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::fmha_v3_varlen_bwd` on `m` under the name
// "fmha_v3_varlen_bwd". Arguments up to `how_v3_bf16_cvt` are required; the
// trailing ones default to `std::nullopt`.
#define MHA_VARLEN_BWD_ASM_PYBIND \
    m.def("fmha_v3_varlen_bwd", \
          &aiter::torch_itfs::fmha_v3_varlen_bwd, \
          py::arg("dout"), \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("out"), \
          py::arg("softmax_lse"), \
          py::arg("cu_seqlens_q"), \
          py::arg("cu_seqlens_k"), \
          py::arg("max_seqlen_q"), \
          py::arg("max_seqlen_k"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("zero_tensors"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("deterministic"), \
          py::arg("is_v3_atomic_fp32"), \
          py::arg("how_v3_bf16_cvt"), \
          py::arg("dq") = std::nullopt, \
          py::arg("dk") = std::nullopt, \
          py::arg("dv") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("rng_state") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::mha_bwd` on `m` under the name "mha_bwd".
// Arguments up to `deterministic` are required; `dq`, `dk`, `dv`, `dbias`,
// `bias`, `alibi_slopes`, `rng_state` and `gen` default to `std::nullopt`.
#define MHA_BWD_PYBIND \
    m.def("mha_bwd", \
          &aiter::torch_itfs::mha_bwd, \
          py::arg("dout"), \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("out"), \
          py::arg("softmax_lse"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("deterministic"), \
          py::arg("dq") = std::nullopt, \
          py::arg("dk") = std::nullopt, \
          py::arg("dv") = std::nullopt, \
          py::arg("dbias") = std::nullopt, \
          py::arg("bias") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("rng_state") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::fmha_v3_fwd` on `m` under the name "fmha_v3_fwd".
// Arguments up to `return_dropout_randval` are required; `out`, `bias`,
// `alibi_slopes` and `gen` default to `std::nullopt`.
#define MHA_FWD_ASM_PYBIND \
    m.def("fmha_v3_fwd", \
          &aiter::torch_itfs::fmha_v3_fwd, \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("return_softmax_lse"), \
          py::arg("return_dropout_randval"), \
          py::arg("out") = std::nullopt, \
          py::arg("bias") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::mha_fwd` on `m` under the name "mha_fwd".
// Arguments up to `return_dropout_randval` are required; `out`, `bias`,
// `alibi_slopes` and `gen` default to `std::nullopt`.
#define MHA_FWD_PYBIND \
    m.def("mha_fwd", \
          &aiter::torch_itfs::mha_fwd, \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("return_softmax_lse"), \
          py::arg("return_dropout_randval"), \
          py::arg("out") = std::nullopt, \
          py::arg("bias") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::mha_varlen_bwd` on `m` under the name
// "mha_varlen_bwd". Arguments up to `deterministic` are required; `dq`, `dk`,
// `dv`, `alibi_slopes`, `rng_state` and `gen` default to `std::nullopt`.
#define MHA_VARLEN_BWD_PYBIND \
    m.def("mha_varlen_bwd", \
          &aiter::torch_itfs::mha_varlen_bwd, \
          py::arg("dout"), \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("out"), \
          py::arg("softmax_lse"), \
          py::arg("cu_seqlens_q"), \
          py::arg("cu_seqlens_k"), \
          py::arg("max_seqlen_q"), \
          py::arg("max_seqlen_k"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("zero_tensors"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("deterministic"), \
          py::arg("dq") = std::nullopt, \
          py::arg("dk") = std::nullopt, \
          py::arg("dv") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("rng_state") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::mha_varlen_fwd` on `m` under the name
// "mha_varlen_fwd". Arguments up to `return_dropout_randval` are required;
// `out`, `block_table`, `bias`, `alibi_slopes` and `gen` default to `std::nullopt`.
#define MHA_VARLEN_FWD_PYBIND \
    m.def("mha_varlen_fwd", \
          &aiter::torch_itfs::mha_varlen_fwd, \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("cu_seqlens_q"), \
          py::arg("cu_seqlens_k"), \
          py::arg("max_seqlen_q"), \
          py::arg("max_seqlen_k"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("logits_soft_cap"), \
          py::arg("zero_tensors"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("return_softmax_lse"), \
          py::arg("return_dropout_randval"), \
          py::arg("out") = std::nullopt, \
          py::arg("block_table") = std::nullopt, \
          py::arg("bias") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `aiter::torch_itfs::mha_batch_prefill` on `m` under the name
// "mha_batch_prefill". Arguments up to `return_dropout_randval` are required;
// `out`, `bias`, `alibi_slopes` and `gen` default to `std::nullopt`.
#define MHA_BATCH_PREFILL_PYBIND \
    m.def("mha_batch_prefill", \
          &aiter::torch_itfs::mha_batch_prefill, \
          py::arg("q"), \
          py::arg("k"), \
          py::arg("v"), \
          py::arg("cu_seqlens_q"), \
          py::arg("kv_indptr"), \
          py::arg("kv_page_indices"), \
          py::arg("max_seqlen_q"), \
          py::arg("max_seqlen_k"), \
          py::arg("dropout_p"), \
          py::arg("softmax_scale"), \
          py::arg("logits_soft_cap"), \
          py::arg("zero_tensors"), \
          py::arg("is_causal"), \
          py::arg("window_size_left"), \
          py::arg("window_size_right"), \
          py::arg("return_softmax_lse"), \
          py::arg("return_dropout_randval"), \
          py::arg("out") = std::nullopt, \
          py::arg("bias") = std::nullopt, \
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("gen") = std::nullopt);

// Registers `ck_moe_stage1` and `ck_moe_stage2` on `m`.
// For both functions the arguments up to `topk` are required; the scale
// arguments and `sorted_weights` default to `std::nullopt`, `block_m` to 32,
// and (stage 1 only) `act_op` to 0.
// The final line deliberately carries no line-continuation backslash, so the
// macro ends on its own rather than depending on a blank line following it.
#define MOE_CK_2STAGES_PYBIND \
    m.def("ck_moe_stage1", \
          &ck_moe_stage1, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("out"), \
          py::arg("topk"), \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("block_m") = 32, \
          py::arg("sorted_weights") = std::nullopt, \
          py::arg("act_op") = 0); \
    m.def("ck_moe_stage2", \
          &ck_moe_stage2, \
          py::arg("inter_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("out"), \
          py::arg("topk"), \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_m") = 32, \
          py::arg("sorted_weights") = std::nullopt);

// Registers `asm_fmoe_stage1`, `asm_fmoe_stage2`, `asm_fmoe_a8` and
// `asm_moe_get_solutions` on `m`.
// For the three `asm_fmoe_*` functions the arguments up to `top_k` are
// required; `scale_a`, `scale_b` and `zero_points` default to `std::nullopt`
// and the remaining integer arguments carry the defaults shown.
// For `asm_moe_get_solutions` the first five arguments are required.
// The final line deliberately carries no line-continuation backslash, so the
// macro ends on its own rather than depending on a blank line following it.
#define MOE_ASM_2STAGES_PYBIND \
    m.def("asm_fmoe_stage1", \
          &asm_fmoe_stage1, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("top_k"), \
          py::arg("scale_a") = std::nullopt, \
          py::arg("scale_b") = std::nullopt, \
          py::arg("zero_points") = std::nullopt, \
          py::arg("mode") = 0, \
          py::arg("solidx") = 0, \
          py::arg("block_size") = 16, \
          py::arg("persist_groups") = 0); \
    m.def("asm_fmoe_stage2", \
          &asm_fmoe_stage2, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("top_k"), \
          py::arg("scale_a") = std::nullopt, \
          py::arg("scale_b") = std::nullopt, \
          py::arg("zero_points") = std::nullopt, \
          py::arg("mode") = 0, \
          py::arg("solidx") = 0, \
          py::arg("block_size") = 16, \
          py::arg("persist_groups") = 0); \
    m.def("asm_fmoe_a8", \
          &asm_fmoe_a8, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("top_k"), \
          py::arg("scale_a") = std::nullopt, \
          py::arg("scale_b") = std::nullopt, \
          py::arg("zero_points") = std::nullopt, \
          py::arg("mode") = 0, \
          py::arg("solidx") = 0, \
          py::arg("out_type") = 0, \
          py::arg("persist_groups") = 0, \
          py::arg("use_shuffle") = 0); \
    m.def("asm_moe_get_solutions", \
          &asm_moe_get_solutions, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("use_int8_w8a16") = false, \
          py::arg("use_int4_w4a16") = false, \
          py::arg("use_int8_w8a8") = false, \
          py::arg("use_int4_w4a8") = false, \
          py::arg("use_fp8_w8a8") = false, \
          py::arg("per_channel_quant") = false, \
          py::arg("w1_zp") = std::nullopt, \
          py::arg("w2_zp") = std::nullopt, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("expert_mask") = std::nullopt);

// Registers `awq_gemm_asm` and `awq_gemm_asm_tuning` on `m`.
// `out`, `mat1` and `mat2` are required; `zero`, `scalar` and `jsonfile`
// default to `std::nullopt`, and `solidx` defaults to 0.
// The final line deliberately carries no line-continuation backslash, so the
// macro ends on its own rather than depending on a blank line following it.
#define AWQ_GEMM_ASM_PYBIND \
    m.def("awq_gemm_asm", \
          &awq_gemm_asm, \
          py::arg("out"), \
          py::arg("mat1"), \
          py::arg("mat2"), \
          py::arg("zero") = std::nullopt, \
          py::arg("scalar") = std::nullopt); \
    m.def("awq_gemm_asm_tuning", \
          &awq_gemm_asm_tuning, \
          py::arg("out"), \
          py::arg("mat1"), \
          py::arg("mat2"), \
          py::arg("zero") = std::nullopt, \
          py::arg("scalar") = std::nullopt, \
          py::arg("solidx") = 0, \
          py::arg("jsonfile") = std::nullopt);

// Registers `awq_dq_asm` on `m`.
// `out` and `mat1` are required; `zero` and `scalar` default to `std::nullopt`.
// The final line deliberately carries no line-continuation backslash, so the
// macro ends on its own rather than depending on a blank line following it.
#define AWQ_DQ_ASM_PYBIND \
    m.def("awq_dq_asm", \
          &awq_dq_asm, \
          py::arg("out"), \
          py::arg("mat1"), \
          py::arg("zero") = std::nullopt, \
          py::arg("scalar") = std::nullopt);


// Registers the CK MoE entry points on `m`: `ck_moe`, `ck_shuffle_moe`,
// `ck_moe_get_solutions`, `ck_moe_stage_1`, `ck_moe_stage_2` and
// `ck_moe_per_token_quant`.
// Arguments written without `= value` are required; the others carry the
// defaults shown (`false`, 0, 32 or `std::nullopt`).
// The final line deliberately carries no line-continuation backslash, so the
// macro ends on its own rather than depending on the line that follows it.
#define MOE_CK_PYBIND \
    m.def("ck_moe", \
          &ck_moe, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("use_int8_w8a16") = false, \
          py::arg("use_int4_w4a16") = false, \
          py::arg("use_int8_w8a8_block") = false, \
          py::arg("use_int4_w4a8_block") = false, \
          py::arg("w1_zp") = std::nullopt, \
          py::arg("w2_zp") = std::nullopt, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("solution_id") = 0, \
          py::arg("expert_mask") = std::nullopt); \
    m.def("ck_shuffle_moe", \
          &ck_shuffle_moe, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("use_int8_w8a16") = false, \
          py::arg("use_int4_w4a16") = false, \
          py::arg("use_int8_w8a8_block") = false, \
          py::arg("use_int4_w4a8_block") = false, \
          py::arg("w1_zp") = std::nullopt, \
          py::arg("w2_zp") = std::nullopt, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("solution_id") = 0, \
          py::arg("expert_mask") = std::nullopt); \
    m.def("ck_moe_get_solutions", \
          &ck_moe_get_solutions, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("use_int8_w8a16") = false, \
          py::arg("use_int4_w4a16") = false, \
          py::arg("use_int8_w8a8_block") = false, \
          py::arg("use_int4_w4a8_block") = false, \
          py::arg("w1_zp") = std::nullopt, \
          py::arg("w2_zp") = std::nullopt, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("expert_mask") = std::nullopt); \
    m.def("ck_moe_stage_1", \
          &ck_moe_stage_1, \
          py::arg("hidden_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_expert_ids"), \
          py::arg("tokens_positions_per_expert"), \
          py::arg("num_valid_ids"), \
          py::arg("out"), \
          py::arg("topk"), \
          py::arg("use_int8_w8a8_block") = false, \
          py::arg("use_fp8_w8a8_block") = false, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("sorted_weights") = std::nullopt, \
          py::arg("act_op") = 0); \
    m.def("ck_moe_stage_2", \
          &ck_moe_stage_2, \
          py::arg("inter_states"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_expert_ids"), \
          py::arg("tokens_positions_per_expert"), \
          py::arg("num_valid_ids"), \
          py::arg("out"), \
          py::arg("topk"), \
          py::arg("use_int8_w8a8_block") = false, \
          py::arg("use_fp8_w8a8_block") = false, \
          py::arg("w2_scale") = std::nullopt, \
          py::arg("a2_scale") = std::nullopt, \
          py::arg("block_shape_n") = 0, \
          py::arg("block_shape_k") = 0, \
          py::arg("block_m") = 32, \
          py::arg("sorted_weights") = std::nullopt); \
    m.def("ck_moe_per_token_quant", \
          &ck_moe_per_token_quant, \
          py::arg("input"), \
          py::arg("out_quant"), \
          py::arg("out_scale"));
            
// Registers the MoE routing/utility ops (top-k gating variants, moe_sum and the
// block-alignment helpers) on a pybind11 module.
// The expansion site must have the module object in scope under the name `m`.
// The trailing string literal of each m.def() is the Python-visible docstring;
// adjacent literals are concatenated by the compiler into a single docstring.
// The final line deliberately carries no line continuation so the macro body
// ends with its last statement instead of relying on a following blank line.
#define MOE_UTILS_PYBIND \
    m.def("topk_softmax", \
          &aiter::topk_softmax, \
          py::arg("topk_weights"), \
          py::arg("topk_indices"), \
          py::arg("token_expert_indices"), \
          py::arg("gating_output"), \
          py::arg("need_renorm"), \
          "Apply topk softmax to the gating outputs."); \
    m.def("grouped_topk", \
          &grouped_topk, \
          py::arg("gating_output"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("num_expert_group"), \
          py::arg("topk_grp"), \
          py::arg("need_renorm"), \
          py::arg("is_softmax") = true, \
          py::arg("routed_scaling_factor") = 1.0f, \
          "Apply grouped topk softmax/sigmoid to the gating outputs."); \
    m.def("biased_grouped_topk", \
          &biased_grouped_topk, \
          py::arg("gating_output"), \
          py::arg("correction_bias"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("num_expert_group"), \
          py::arg("topk_grp"), \
          py::arg("need_renorm"), \
          py::arg("routed_scaling_factor") = 1.0f, \
          "Apply biased grouped topk softmax to the gating outputs."); \
    m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); \
    m.def("moe_fused_gate", \
          &moe_fused_gate, \
          py::arg("input"), \
          py::arg("bias"), \
          py::arg("topk_weights"), \
          py::arg("topk_ids"), \
          py::arg("num_expert_group"), \
          py::arg("topk_group"), \
          py::arg("topk"), \
          py::arg("num_fused_shared_experts"), \
          py::arg("routed_scaling_factor") = 1.0, \
          "Apply biased grouped topk softmax to the gating outputs."); \
    m.def("moe_align_block_size", &moe_align_block_size, \
          "moe_align_block_size(Tensor topk_ids, int num_experts," \
          "                     int block_size, Tensor! sorted_token_ids," \
          "                     Tensor! experts_ids," \
          "                     Tensor! num_tokens_post_pad) -> ()"); \
    m.def("sgl_moe_align_block_size", &sgl_moe_align_block_size, \
          "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," \
          "                         int block_size, Tensor! sorted_token_ids," \
          "                         Tensor! experts_ids," \
          "                         Tensor! num_tokens_post_pad) -> ()");


// Registers the fused-MoE kernels (fmoe*, moe_stage1_g1u1) on a pybind11 module.
// The expansion site must have the module object in scope under the name `m`.
// Bindings without py::arg(...) entries (fmoe, fmoe_int8_g1u0_a16, fmoe_g1u1_a16)
// are positional-only from Python; the others accept keyword arguments and the
// defaults listed here.
// The final line carries no line continuation: a trailing backslash there would
// make the macro body depend on the blank line that follows it.
#define MOE_OP_PYBIND \
    m.def("fmoe", &fmoe); \
    m.def("fmoe_int8_g1u0", &fmoe_int8_g1u0, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("topk"), \
          py::arg("input_scale"), \
          py::arg("fc1_scale"), \
          py::arg("fc2_scale"), \
          py::arg("fc2_smooth_scale") = std::nullopt, \
          py::arg("activation") = ActivationType::Silu); \
    m.def("fmoe_g1u1", &fmoe_g1u1, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("topk"), \
          py::arg("input_scale"), \
          py::arg("fc1_scale"), \
          py::arg("fc2_scale"), \
          py::arg("fc2_smooth_scale") = std::nullopt, \
          py::arg("activation") = ActivationType::Silu); \
    m.def("fmoe_g1u1_tkw1", &fmoe_g1u1_tkw1, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("topk"), \
          py::arg("input_scale"), \
          py::arg("fc1_scale"), \
          py::arg("fc2_scale"), \
          py::arg("fc2_smooth_scale") = std::nullopt, \
          py::arg("activation") = ActivationType::Silu); \
    m.def("fmoe_int8_g1u0_a16", &fmoe_int8_g1u0_a16); \
    m.def("fmoe_g1u1_a16", &fmoe_g1u1_a16); \
    m.def("fmoe_fp8_blockscale_g1u1", &fmoe_fp8_blockscale_g1u1, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("gate"), \
          py::arg("down"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("topk"), \
          py::arg("input_scale"), \
          py::arg("fc1_scale"), \
          py::arg("fc2_scale"), \
          py::arg("fc_scale_blkn") = 128, \
          py::arg("fc_scale_blkk") = 128, \
          py::arg("fc2_smooth_scale") = std::nullopt, \
          py::arg("activation") = ActivationType::Silu); \
    m.def("moe_stage1_g1u1", &moe_stage1_g1u1, \
          py::arg("input"), \
          py::arg("w1"), \
          py::arg("w2"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_expert_ids"), \
          py::arg("num_valid_ids"), \
          py::arg("out"), \
          py::arg("inter_dim"), \
          py::arg("kernelName"), \
          py::arg("block_m"), \
          py::arg("ksplit") = 0, \
          py::arg("activation") = ActivationType::Silu, \
          py::arg("quant_type") = QuantType::No, \
          py::arg("a1_scale") = std::nullopt, \
          py::arg("w1_scale") = std::nullopt, \
          py::arg("sorted_weights") = std::nullopt);

// Registers asm_moe_sum on the pybind11 module `m` (must be in scope at the
// expansion site). The string literal is the Python-visible docstring.
// No trailing line continuation: the macro body ends with this statement.
#define MOE_SUM_PYBIND \
    m.def("asm_moe_sum", \
          &asm_moe_sum, \
          "asm_moe_sum(Tensor! input, Tensor output, Tensor sorted_ids) -> ()");

// Registers moe_sorting_fwd on the pybind11 module `m` (must be in scope at the
// expansion site). Every parameter is exposed as a keyword argument;
// `local_expert_mask` is the only optional one and defaults to std::nullopt.
#define MOE_SORTING_PYBIND \
    m.def("moe_sorting_fwd", &moe_sorting_fwd, \
          py::arg("topk_ids"), \
          py::arg("topk_weights"), \
          py::arg("sorted_token_ids"), \
          py::arg("sorted_weights"), \
          py::arg("sorted_expert_ids"), \
          py::arg("tokens_positions_per_expert"), \
          py::arg("num_valid_ids"), \
          py::arg("moe_buf"), \
          py::arg("num_experts"), \
          py::arg("unit_size"), \
          py::arg("local_expert_mask") = std::nullopt);

// Registers the layernorm2d family on the pybind11 module `m` (must be in scope
// at the expansion site). The Python names carry a `_fwd` infix that the bound
// C++ symbols do not (e.g. "layernorm2d_fwd_with_add" -> layernorm2d_with_add).
// Every binding takes an optional trailing `x_bias` defaulting to std::nullopt.
#define NORM_PYBIND \
    m.def("layernorm2d_fwd", &layernorm2d, \
          py::arg("input"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt); \
    m.def("layernorm2d_fwd_with_add", &layernorm2d_with_add, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt); \
    m.def("layernorm2d_fwd_with_smoothquant", &layernorm2d_with_smoothquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("xscale"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt); \
    m.def("layernorm2d_fwd_with_add_smoothquant", &layernorm2d_with_add_smoothquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("xscale"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt); \
    m.def("layernorm2d_fwd_with_dynamicquant", &layernorm2d_with_dynamicquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt); \
    m.def("layernorm2d_fwd_with_add_dynamicquant", &layernorm2d_with_add_dynamicquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("bias"), \
          py::arg("epsilon"), \
          py::arg("x_bias") = std::nullopt);
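      // Disabled bindings, kept for reference only. They sit outside any macro body
      // (no line continuation, so nothing here is spliced into a neighbouring line).
      // m.def("layernorm2d_with_add_asm", &layernorm2d_with_add_asm);
      // m.def("layernorm2d_with_add_smoothquant_asm", &layernorm2d_with_add_smoothquant_asm);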

// Registers the rotary-embedding ops on the pybind11 module `m` (must be in
// scope at the expansion site). Note that the Python name "rotary_embedding_fwd"
// is bound to the C++ symbol rotary_embedding; the string after each function
// pointer is the Python-visible docstring.
#define POS_ENCODING_PYBIND \
    m.def("rotary_embedding_fwd", \
          &rotary_embedding, \
          "rotary_embedding"); \
    m.def("batched_rotary_embedding", \
          &batched_rotary_embedding, \
          "batched_rotary_embedding");

// Registers the quantization ops from namespace aiter on the pybind11 module `m`
// (must be in scope at the expansion site).
// static_per_tensor_quant / dynamic_per_tensor_quant are bound without py::arg
// entries and are therefore positional-only from Python; the remaining bindings
// accept keyword arguments with the defaults listed here.
// The macro body must stay one uninterrupted run of continued lines: any line
// without a trailing backslash ends the macro and leaves the rest at file scope.
#define QUANT_PYBIND \
    m.def("static_per_tensor_quant", &aiter::static_per_tensor_quant); \
    m.def("dynamic_per_tensor_quant", &aiter::dynamic_per_tensor_quant); \
    m.def("dynamic_per_token_scaled_quant", \
          &aiter::dynamic_per_token_scaled_quant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("scales"), \
          py::arg("scale_ub") = std::nullopt, \
          py::arg("shuffle_scale") = false, \
          py::arg("num_rows") = std::nullopt, \
          py::arg("num_rows_factor") = 1); \
    m.def("dynamic_per_group_scaled_quant_fp4", \
          &aiter::dynamic_per_group_scaled_quant_fp4, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("scales"), \
          py::arg("group_size") = 32, \
          py::arg("shuffle_scale") = true, \
          py::arg("num_rows") = std::nullopt, \
          py::arg("num_rows_factor") = 1); \
    m.def("smooth_per_token_scaled_quant", \
          &aiter::smooth_per_token_scaled_quant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("scales"), \
          py::arg("smooth_scale"), \
          py::arg("smooth_scale_map") = std::nullopt, \
          py::arg("shuffle_scale") = false, \
          py::arg("num_rows") = std::nullopt, \
          py::arg("num_rows_factor") = 1); \
    m.def("partial_transpose", \
          &aiter::partial_transpose, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("num_rows")); \
    m.def("moe_swiglu_dynamic_quant", \
          &aiter::moe_swiglu_dynamic_quant, \
          py::arg("scatter_tokens"), \
          py::arg("smooth"), \
          py::arg("experts_tokens_count"), \
          py::arg("experts_tokens_start"), \
          py::arg("output"), \
          py::arg("scales"), \
          py::arg("beta") = 1.0f);

// Registers the RMS-norm ops on the pybind11 module `m` (must be in scope at the
// expansion site).
// "rms_norm_cu" and "fused_add_rms_norm_cu" are bound without py::arg entries
// (positional-only from Python) and carry a docstring; the rmsnorm2d_fwd* and
// head_rms_norm bindings accept keyword arguments. Only
// rmsnorm2d_fwd_with_add_smoothquant has an optional parameter
// (`out_before_quant`, default std::nullopt).
// The macro body must stay one uninterrupted run of continued lines: any line
// without a trailing backslash ends the macro and leaves the rest at file scope.
#define RMSNORM_PYBIND \
    m.def("rms_norm_cu", \
          &rms_norm, \
          "Apply Root Mean Square (RMS) Normalization to the input tensor."); \
    m.def("fused_add_rms_norm_cu", \
          &fused_add_rms_norm, \
          "In-place fused Add and RMS Normalization"); \
    m.def("rmsnorm2d_fwd", \
          &rmsnorm2d, \
          py::arg("input"), \
          py::arg("weight"), \
          py::arg("epsilon")); \
    m.def("rmsnorm2d_fwd_with_add", \
          &rmsnorm2d_with_add, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("weight"), \
          py::arg("epsilon")); \
    m.def("rmsnorm2d_fwd_with_smoothquant", \
          &rmsnorm2d_with_smoothquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("xscale"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("epsilon")); \
    m.def("rmsnorm2d_fwd_with_add_smoothquant", \
          &rmsnorm2d_with_add_smoothquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("xscale"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("epsilon"), \
          py::arg("out_before_quant") = std::nullopt); \
    m.def("rmsnorm2d_fwd_with_dynamicquant", \
          &rmsnorm2d_with_dynamicquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("epsilon")); \
    m.def("rmsnorm2d_fwd_with_add_dynamicquant", \
          &rmsnorm2d_with_add_dynamicquant, \
          py::arg("out"), \
          py::arg("input"), \
          py::arg("residual_in"), \
          py::arg("residual_out"), \
          py::arg("yscale"), \
          py::arg("weight"), \
          py::arg("epsilon")); \
    m.def("head_rms_norm", \
          &head_rms_norm, \
          py::arg("input"), \
          py::arg("weight"), \
          py::arg("epsilon"), \
          py::arg("norm_head_dim"));
// Registers the forward RoPE entry points on the pybind11 module `m` (must be
// in scope at the expansion site). Each Python name equals the bound C++ symbol;
// no py::arg entries are given, so the functions are positional-only from Python.
#define ROPE_GENERAL_FWD_PYBIND \
    m.def("rope_fwd_impl", &rope_fwd_impl); \
    m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); \
    m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); \
    m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); \
    m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); \
    m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl);

// Registers the backward RoPE entry points on the pybind11 module `m` (must be
// in scope at the expansion site). Each Python name equals the bound C++ symbol;
// no py::arg entries are given, so the functions are positional-only from Python.
#define ROPE_GENERAL_BWD_PYBIND \
    m.def("rope_bwd_impl", &rope_bwd_impl); \
    m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); \
    m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); \
    m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); \
    m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); \
    m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl);

// Registers the position-indexed forward RoPE entry points (with and without
// offsets, single and "2c" variants) on the pybind11 module `m`, which must be
// in scope at the expansion site. Each Python name equals the bound C++ symbol.
#define ROPE_POS_FWD_PYBIND \
    m.def("rope_cached_positions_fwd_impl", \
          &rope_cached_positions_fwd_impl); \
    m.def("rope_cached_positions_2c_fwd_impl", \
          &rope_cached_positions_2c_fwd_impl); \
    m.def("rope_cached_positions_offsets_fwd_impl", \
          &rope_cached_positions_offsets_fwd_impl); \
    m.def("rope_cached_positions_offsets_2c_fwd_impl", \
          &rope_cached_positions_offsets_2c_fwd_impl);

// Registers fused_qk_norm_mrope_3d_cache_pts_quant_shuffle on the pybind11
// module `m` (must be in scope at the expansion site). All parameters are
// exposed as keyword arguments in the order listed; only `rotary_dim` has a
// default (0).
#define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND \
    m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle", \
          &fused_qk_norm_mrope_3d_cache_pts_quant_shuffle, \
          py::arg("qkv"), py::arg("qw"), py::arg("kw"), \
          py::arg("cos_sin"), py::arg("positions"), \
          py::arg("num_tokens"), \
          py::arg("num_heads_q"), py::arg("num_heads_k"), py::arg("num_heads_v"), \
          py::arg("head_size"), \
          py::arg("is_neox_style"), py::arg("mrope_section_"), py::arg("is_interleaved"), \
          py::arg("eps"), \
          py::arg("q_out"), py::arg("k_cache"), py::arg("v_cache"), \
          py::arg("slot_mapping"), \
          py::arg("per_tensor_k_scale"), py::arg("per_tensor_v_scale"), \
          py::arg("k_out"), py::arg("v_out"), \
          py::arg("return_kv"), py::arg("use_shuffle_layout"), \
          py::arg("block_size"), py::arg("x"), \
          py::arg("rotary_dim") = 0);

// Registers the fused QK-norm + RoPE + cache-quant ops from namespace aiter on
// the pybind11 module `m` (must be in scope at the expansion site).
// fused_qk_norm_rope_cache_quant_shuffle and fused_qk_norm_rope_2way are bound
// without py::arg entries (positional-only from Python); the *_pts_* and
// *_block_* variants expose keyword arguments, each with a single defaulted
// trailing parameter (`rotary_dim` = 0 and `max_tokens_per_batch` = 0).
#define FUSED_QKNORM_ROPE_CACHE_QUANT_PYBIND \
    m.def("fused_qk_norm_rope_cache_quant_shuffle", \
          &aiter::fused_qk_norm_rope_cache_quant_shuffle); \
    m.def("fused_qk_norm_rope_cache_pts_quant_shuffle", \
          &aiter::fused_qk_norm_rope_cache_pts_quant_shuffle, \
          py::arg("qkv"), py::arg("qw"), py::arg("kw"), \
          py::arg("cos_sin"), py::arg("positions"), \
          py::arg("num_tokens"), \
          py::arg("num_heads_q"), py::arg("num_heads_k"), py::arg("num_heads_v"), \
          py::arg("head_size"), \
          py::arg("is_neox_style"), \
          py::arg("eps"), \
          py::arg("q_out"), py::arg("k_cache"), py::arg("v_cache"), \
          py::arg("slot_mapping"), \
          py::arg("per_tensor_k_scale"), py::arg("per_tensor_v_scale"), \
          py::arg("k_out"), py::arg("v_out"), \
          py::arg("return_kv"), py::arg("use_shuffle_layout"), \
          py::arg("block_size"), py::arg("x"), \
          py::arg("rotary_dim") = 0); \
    m.def("fused_qk_norm_rope_cache_block_quant_shuffle", \
          &aiter::fused_qk_norm_rope_cache_block_quant_shuffle, \
          py::arg("qkv"), \
          py::arg("num_heads_q"), py::arg("num_heads_k"), py::arg("num_heads_v"), \
          py::arg("head_dim"), \
          py::arg("eps"), \
          py::arg("q_weight"), py::arg("k_weight"), \
          py::arg("cos_sin_cache"), py::arg("is_neox"), py::arg("position_ids"), \
          py::arg("k_cache"), py::arg("v_cache"), \
          py::arg("slot_mapping"), py::arg("cu_q_len"), \
          py::arg("kv_cache_dtype"), \
          py::arg("k_scale"), py::arg("v_scale"), \
          py::arg("max_tokens_per_batch") = 0); \
    m.def("fused_qk_norm_rope_2way", &aiter::fused_qk_norm_rope_2way);

// Registers smoothquant_fwd and moe_smoothquant_fwd on the pybind11 module `m`
// (must be in scope at the expansion site). No py::arg entries are given, so
// both functions are positional-only from Python.
#define SMOOTHQUANT_PYBIND \
    m.def("smoothquant_fwd", &smoothquant_fwd); \
    m.def("moe_smoothquant_fwd", &moe_smoothquant_fwd);

// Registers the hipb_* GEMM extension entry points and getHipblasltKernelName on
// the pybind11 module `m` (must be in scope at the expansion site).
// The string literal directly after each function pointer is the Python-visible
// docstring. hipb_mm and hipb_findallsols take `mat1`/`mat2` (plus
// `solution_index` for hipb_mm) as required arguments; every other keyword
// defaults to std::nullopt. Note the differing names of the last scale keyword:
// `scaleOut` for hipb_mm versus `scaleC` for hipb_findallsols.
#define HIPBSOLGEMM_PYBIND \
    m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); \
    m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); \
    m.def("hipb_mm", &hipb_mm, "hipb_mm", \
          py::arg("mat1"), \
          py::arg("mat2"), \
          py::arg("solution_index"), \
          py::arg("bias") = std::nullopt, \
          py::arg("out_dtype") = std::nullopt, \
          py::arg("scaleA") = std::nullopt, \
          py::arg("scaleB") = std::nullopt, \
          py::arg("scaleOut") = std::nullopt, \
          py::arg("scaleType") = std::nullopt); \
    m.def("hipb_findallsols", &hipb_findallsols, "hipb_findallsols", \
          py::arg("mat1"), \
          py::arg("mat2"), \
          py::arg("bias") = std::nullopt, \
          py::arg("out_dtype") = std::nullopt, \
          py::arg("scaleA") = std::nullopt, \
          py::arg("scaleB") = std::nullopt, \
          py::arg("scaleC") = std::nullopt, \
          py::arg("scaleType") = std::nullopt); \
    m.def("getHipblasltKernelName", &getHipblasltKernelName);

// Bindings for the "rocb_*" entry points.
// Expands to four m.def(...) statements against a module object named `m`
// (presumably a pybind11 module). Note that the Python names "rocb_mm" and
// "rocb_findallsols" are bound to the C++ symbols RocSolIdxBlas and
// RocFindAllSolIdxBlas respectively; the third argument of each m.def is the
// docstring.
#define ROCSOLGEMM_PYBIND                        \
    m.def("rocb_create_extension",               \
          &rocb_create_extension,                \
          "create_extension");                   \
    m.def("rocb_destroy_extension",              \
          &rocb_destroy_extension,               \
          "destroy_extension");                  \
    m.def("rocb_mm",                             \
          &RocSolIdxBlas,                        \
          "mm");                                 \
    m.def("rocb_findallsols",                    \
          &RocFindAllSolIdxBlas,                 \
          "rocblas_find_all_sols");

// Exposes the QuantType and ActivationType enums on module `m`, exports their
// enumerators into the module scope (export_values), and registers an
// implicit int -> enum conversion for each.
// NOTE(review): the QuantType registration here mirrors the one inside
// AITER_CORE_PYBIND near the top of this file; presumably only one of the two
// macros is meant to be expanded per module -- confirm against the callers.
#define AITER_ENUM_PYBIND                                        \
    /* QuantType and its int conversion */                       \
    pybind11::enum_<QuantType>(m, "QuantType")                   \
        .value("No",          QuantType::No)                     \
        .value("per_Tensor",  QuantType::per_Tensor)             \
        .value("per_Token",   QuantType::per_Token)              \
        .value("per_1x32",    QuantType::per_1x32)               \
        .value("per_1x128",   QuantType::per_1x128)              \
        .value("per_128x128", QuantType::per_128x128)            \
        .export_values();                                        \
    pybind11::implicitly_convertible<int, QuantType>();          \
    /* ActivationType and its int conversion */                  \
    pybind11::enum_<ActivationType>(m, "ActivationType")         \
        .value("No",   ActivationType::No)                       \
        .value("Silu", ActivationType::Silu)                     \
        .value("Gelu", ActivationType::Gelu)                     \
        .export_values();                                        \
    pybind11::implicitly_convertible<int, ActivationType>();

// Binding for topk_plain on module `m` (presumably a pybind11 module).
// The first four arguments are required; the remaining five carry the
// defaults spelled out below (rowStarts / rowEnds default to a
// default-constructed torch::Tensor).
#define TOPK_PLAIN_PYBIND                                    \
    m.def("topk_plain", &topk_plain,                         \
          /* required */                                     \
          py::arg("values"), py::arg("topk_ids"),            \
          py::arg("topk_out"), py::arg("topk"),              \
          /* optional, with defaults */                      \
          py::arg("largest") = true,                         \
          py::arg("rowStarts") = torch::Tensor(),            \
          py::arg("rowEnds") = torch::Tensor(),              \
          py::arg("stride0") = -1,                           \
          py::arg("stride1") = 1);

// Bindings for the three fast_topk*_interface entry points on module `m`
// (presumably a pybind11 module). In each binding every argument is required
// except the trailing "row_starts_opt", which defaults to std::nullopt.
#define TOPK_TRANSFORM_PYBIND                                          \
    /* fast_topk_interface */                                          \
    m.def("fast_topk_interface", &fast_topk_interface,                 \
          py::arg("score"), py::arg("indices"), py::arg("lengths"),    \
          py::arg("row_starts_opt") = std::nullopt);                   \
    /* fast_topk_transform_interface */                                \
    m.def("fast_topk_transform_interface",                             \
          &fast_topk_transform_interface,                              \
          py::arg("score"), py::arg("lengths"),                        \
          py::arg("dst_page_table"), py::arg("src_page_table"),        \
          py::arg("cu_seqlens_q"),                                     \
          py::arg("row_starts_opt") = std::nullopt);                   \
    /* fast_topk_transform_ragged_interface */                         \
    m.def("fast_topk_transform_ragged_interface",                      \
          &fast_topk_transform_ragged_interface,                       \
          py::arg("score"), py::arg("lengths"),                        \
          py::arg("topk_indices_ragged"),                              \
          py::arg("topk_indices_offset"),                              \
          py::arg("row_starts_opt") = std::nullopt);

#define MOE_C_PYBIND                                                                 \
      m.def("moe_c_moe_gemm_marlin_w8a8",                                                  \
      &moe_c_moe_gemm_marlin_w8a8,                                                         \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("a_scale"),                                                            \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
      py::arg("delta"),                                                              \
      py::arg("size_m")                                                              \
      );                                                                              \
      m.def("moe_c_moe_gemm_marlin_w8a8_tensorwise",                                       \
      &moe_c_moe_gemm_marlin_w8a8_tensorwise,                                              \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("a_scale"),                                                            \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
      py::arg("delta"),                                                              \
      py::arg("size_m")                                                              \
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
      );                                                                              \
      m.def("moe_c_moe_gemm_marlin_w4a8",                                                  \
      &moe_c_moe_gemm_marlin_w4a8,                                                         \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("a_scale"),                                                            \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
1300
1301
      py::arg("delta"),                                                              \
      py::arg("size_m")                                                              \
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
      );                                                                              \
      m.def("moe_c_moe_gemm_marlin_w8a8_fp8",                                                  \
      &moe_c_moe_gemm_marlin_w8a8_fp8,                                                         \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("a_scale"),                                                            \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
      py::arg("delta"),                                                              \
      py::arg("size_m")                                                              \
      );                                                                              \
      m.def("moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise",                                    \
      &moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise,                                           \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("a_scale"),                                                            \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
      py::arg("delta"),                                                              \
      py::arg("size_m")                                                              \
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
      );                                                                              \
      m.def("moe_c_moe_gemm_marlin_w4a16",                                                  \
      &moe_c_moe_gemm_marlin_w4a16,                                                         \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("b_scale"),                                                            \
      py::arg("b_zeros"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
      py::arg("delta")                                                              \
      );                                                                               \
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
      m.def("moe_c_moe_gemm_marlin_w8a16",                                                  \
      &moe_c_moe_gemm_marlin_w8a16,                                                         \
      py::arg("input"),                                                              \
      py::arg("b_qweight"),                                                          \
      py::arg("output"),                                                             \
      py::arg("b_scale"),                                                            \
      py::arg("topk_weights") ,                                                      \
      py::arg("sorted_token_ids"),                                                   \
      py::arg("expert_ids"),                                                         \
      py::arg("num_tokens_post_pad"),                                                \
      py::arg("top_k"),                                                              \
      py::arg("mode"),                                                               \
      py::arg("delta")                                                              \
      );                                                                             \
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
      m.def("moe_c_moe_w8a16_gemm_block_wise",                                                  \
        &moe_c_moe_w8a16_gemm_block_wise,                                                   \
        py::arg("input"),                                                             \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                             \
        py::arg("topk_weights")  ,                                                        \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("group_size_n"),                                                      \
        py::arg("group_size_k"),                                                      \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_w8a16_gemm_awq ==================== */                 \
  m.def("moe_c_moe_w8a16_gemm_awq",                                                         \
        &moe_c_moe_w8a16_gemm_awq,                                                          \
        py::arg("input"),                                                             \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                             \
        py::arg("topk_weights") ,                                                         \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_wna16_gemm ==================== */                     \
  m.def("moe_c_moe_wna16_gemm",                                                             \
        &moe_c_moe_wna16_gemm,                                                              \
        py::arg("input"),                                                             \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros")  ,                                                           \
        py::arg("topk_weights") ,                                                    \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("kloops"),                                                            \
        py::arg("nloops"),                                                            \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_wna16_gemm_2 ==================== */                    \
  m.def("moe_c_moe_wna16_gemm_2",                                                           \
        &moe_c_moe_wna16_gemm_2,                                                            \
        py::arg("input"),                                                             \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                        \
        py::arg("topk_weights"),                                                          \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("kloops"),                                                            \
        py::arg("nloops"),                                                            \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
                                                                                          \
  /* ==================== moe_align_block_size ==================== */               \
  m.def("moe_c_moe_align_block_size",                                                       \
        &moe_c_moe_align_block_size,                                                        \
        py::arg("topk_ids"),                                                          \
        py::arg("num_experts"),                                                       \
        py::arg("block_size"),                                                        \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("experts_ids"),                                                       \
        py::arg("num_tokens_post_pad")                                                \
  );                                                                                  \
                                                                                      \
                                                                                      \
  /* ==================== moe_wna16_gemm_base ==================== */                \
  m.def("moe_c_moe_wna16_gemm_base",                                                        \
        &moe_c_moe_wna16_gemm_base,                                                         \
        py::arg("input"),                                                             \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros"),                                                          \
        py::arg("topk_weights"),                                                     \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== sgl_moe_align_block_size ==================== */           \
  m.def("moe_c_sgl_moe_align_block_size",                                                   \
        &moe_c_sgl_moe_align_block_size,                                                    \
        py::arg("topk_ids"),                                                          \
        py::arg("num_experts"),                                                       \
        py::arg("block_size"),                                                        \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("experts_ids"),                                                       \
        py::arg("num_tokens_post_pad")                                                \
  );                                                                                  \
                                                                                      \
\
                                                                                      \
  /* ==================== moe_w8a8_gemm_block_wise ==================== */           \
  m.def("moe_c_moe_w8a8_gemm_block_wise",                                                  \
        &moe_c_moe_w8a8_gemm_block_wise,                                                   \
        py::arg("input"),                                                             \
        py::arg("a_scales"),                                                          \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                           \
        py::arg("topk_weights"),                                                        \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("group_size_n"),                                                      \
        py::arg("group_size_k"),                                                      \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("kloops"),                                                            \
        py::arg("nloops"),                                                            \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_w8a8_gemm_block_wise_kernel2 ==================== */   \
  m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2",                                          \
        &moe_c_moe_w8a8_gemm_block_wise_kernel2,                                           \
        py::arg("input"),                                                             \
        py::arg("a_scales"),                                                          \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                           \
        py::arg("topk_weights") ,                                                   \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("group_size_n"),                                                      \
        py::arg("group_size_k"),                                                      \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("kloops"),                                                            \
        py::arg("nloops"),                                                            \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_w8a8_gemm_block_wise_fp8 ==================== */       \
  m.def("moe_c_moe_w8a8_gemm_block_wise_fp8",                                              \
        &moe_c_moe_w8a8_gemm_block_wise_fp8,                                               \
        py::arg("input"),                                                             \
        py::arg("a_scales"),                                                          \
        py::arg("output"),                                                            \
        py::arg("b_qweight"),                                                         \
        py::arg("b_scales"),                                                          \
        py::arg("b_qzeros") ,                                                        \
        py::arg("topk_weights"),                                                       \
        py::arg("sorted_token_ids"),                                                  \
        py::arg("expert_ids"),                                                        \
        py::arg("num_tokens_post_pad"),                                               \
        py::arg("group_size_n"),                                                      \
        py::arg("group_size_k"),                                                      \
        py::arg("top_k"),                                                             \
        py::arg("BLOCK_SIZE_M"),                                                      \
        py::arg("BLOCK_SIZE_N"),                                                      \
        py::arg("BLOCK_SIZE_K"),                                                      \
        py::arg("kloops"),                                                            \
        py::arg("nloops"),                                                            \
        py::arg("bit")                                                                \
  );                                                                                  \
                                                                                      \
  /* ==================== moe_w8a8_gemm_block_wise_kernel2_fp8 ==================== */   \
  /* Bound under the C++ function's own name. The keyword list is the same,   */         \
  /* in the same order, as the one used for the non-kernel2 binding. py::arg  */         \
  /* entries are positional with respect to the C++ signature and none has a  */         \
  /* default, so all 19 arguments are required from Python.                   */         \
  m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8",                                    \
        &moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8,                                     \
        py::arg("input"), py::arg("a_scales"), py::arg("output"),                        \
        py::arg("b_qweight"), py::arg("b_scales"), py::arg("b_qzeros"),                  \
        py::arg("topk_weights"), py::arg("sorted_token_ids"), py::arg("expert_ids"),     \
        py::arg("num_tokens_post_pad"),                                                  \
        py::arg("group_size_n"), py::arg("group_size_k"), py::arg("top_k"),              \
        py::arg("BLOCK_SIZE_M"), py::arg("BLOCK_SIZE_N"), py::arg("BLOCK_SIZE_K"),       \
        py::arg("kloops"), py::arg("nloops"), py::arg("bit"));                           \
  /* ==================== topk_softmax ==================== */                           \
  /* All four arguments are required; py::arg order follows the C++ signature. */        \
  m.def("moe_c_topk_softmax",                                                            \
        &moe_c_topk_softmax,                                                             \
        py::arg("topk_weights"),                                                         \
        py::arg("topk_indices"),                                                         \
        py::arg("token_expert_indices"),                                                 \
        py::arg("gating_output"));                                                       \
                                                                                         \
  /* ==================== silu_and_mul ==================== */                           \
  /* rows_per_block and vec_size may be omitted from Python; they then take    */        \
  /* the defaults 1 and 2 declared here.                                       */        \
  m.def("moe_c_silu_and_mul",                                                            \
        &moe_c_silu_and_mul,                                                             \
        py::arg("out"),                                                                  \
        py::arg("input"),                                                                \
        py::arg("rows_per_block") = 1,                                                   \
        py::arg("vec_size")       = 2);                                                  \
                                                                                         \
  /* ==================== moe_sum_opt_v2 ==================== */                         \
  /* routed_scaling_factor may be omitted from Python; it then defaults to 1.0. */       \
  /* This is the last statement of the enclosing macro, so the final line      */        \
  /* deliberately has no continuation backslash.                               */        \
  m.def("moe_c_moe_sum_opt_v2",                                                          \
        &moe_c_moe_sum_opt_v2,                                                           \
        py::arg("input"),                                                                \
        py::arg("output"),                                                               \
        py::arg("routed_scaling_factor") = 1.0);

  // MHC_PYBIND registers the aiter::mhc_* functions on a pybind11 module object
  // named `m`, which must be in scope wherever the macro is expanded. Every
  // binding passes its own name as the docstring and exposes its parameters as
  // keyword arguments (matched to the C++ parameters by position); entries
  // written as `py::arg(...) = value` are optional from Python.
  #define MHC_PYBIND                                                                     \
    /* mhc_pre_gemm_sqrsum: tile_k defaults to 128, use_tf32 to false. */                \
    m.def("mhc_pre_gemm_sqrsum", &aiter::mhc_pre_gemm_sqrsum, "mhc_pre_gemm_sqrsum",     \
          py::arg("out"), py::arg("sqrsum"), py::arg("x"), py::arg("fn"),                \
          py::arg("tile_k") = 128, py::arg("use_tf32") = false);                         \
    /* mhc_pre_gemm_sqrsum_stage1_m128: use_tf32 defaults to false. */                   \
    m.def("mhc_pre_gemm_sqrsum_stage1_m128", &aiter::mhc_pre_gemm_sqrsum_stage1_m128,    \
          "mhc_pre_gemm_sqrsum_stage1_m128",                                             \
          py::arg("out"), py::arg("sqrsum"), py::arg("x"), py::arg("fn"),                \
          py::arg("use_tf32") = false);                                                  \
    /* mhc_pre_reduce_splitk: all four arguments are required. */                        \
    m.def("mhc_pre_reduce_splitk", &aiter::mhc_pre_reduce_splitk,                        \
          "mhc_pre_reduce_splitk",                                                       \
          py::arg("out_red"), py::arg("sqrsum_red"), py::arg("out"), py::arg("sqrsum")); \
    /* mhc_pre_big_fuse: eight required arguments, then five with defaults. */           \
    m.def("mhc_pre_big_fuse", &aiter::mhc_pre_big_fuse, "mhc_pre_big_fuse",              \
          py::arg("post_mix"), py::arg("comb_mix"), py::arg("layer_input"),              \
          py::arg("gemm_out_mul"), py::arg("gemm_out_sqrsum"),                           \
          py::arg("hc_scale"), py::arg("hc_base"), py::arg("residual"),                  \
          py::arg("rms_eps") = 1e-6, py::arg("hc_pre_eps") = 1e-6,                       \
          py::arg("hc_sinkhorn_eps") = 1e-6, py::arg("hc_post_mult_value") = 1.0,        \
          py::arg("sinkhorn_repeat") = 20);                                              \
    /* mhc_pre_big_fuse_tlstyle: same keyword list and defaults as mhc_pre_big_fuse. */  \
    m.def("mhc_pre_big_fuse_tlstyle", &aiter::mhc_pre_big_fuse_tlstyle,                  \
          "mhc_pre_big_fuse_tlstyle",                                                    \
          py::arg("post_mix"), py::arg("comb_mix"), py::arg("layer_input"),              \
          py::arg("gemm_out_mul"), py::arg("gemm_out_sqrsum"),                           \
          py::arg("hc_scale"), py::arg("hc_base"), py::arg("residual"),                  \
          py::arg("rms_eps") = 1e-6, py::arg("hc_pre_eps") = 1e-6,                       \
          py::arg("hc_sinkhorn_eps") = 1e-6, py::arg("hc_post_mult_value") = 1.0,        \
          py::arg("sinkhorn_repeat") = 20);                                              \
    /* mhc_post: all five arguments are required. Last statement of the macro, */        \
    /* so its final line has no continuation backslash.                        */        \
    m.def("mhc_post", &aiter::mhc_post, "mhc_post",                                      \
          py::arg("out"), py::arg("x"), py::arg("residual"),                             \
          py::arg("post_layer_mix"), py::arg("comb_res_mix"));