rocm_ops.hpp 101 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
// SPDX-License-Identifier: MIT

// ACTIVATION_PYBIND: registers the fused activation bindings on module `m`.
// Every entry declares the keyword arguments `out` and `input`;
// scaled_silu_and_mul additionally declares `scale`. None have defaults.
// The expansion site must provide `m` and the `py` namespace alias.
#define ACTIVATION_PYBIND                                                      \
    m.def("silu_and_mul",                                                      \
          &aiter::silu_and_mul,                                                \
          "Activation function used in SwiGLU.",                               \
          py::arg("out"),                                                      \
          py::arg("input"));                                                   \
    m.def("scaled_silu_and_mul",                                               \
          &aiter::scaled_silu_and_mul,                                         \
          "Activation function used in scaled SwiGLU.",                        \
          py::arg("out"),                                                      \
          py::arg("input"),                                                    \
          py::arg("scale"));                                                   \
    m.def("gelu_and_mul",                                                      \
          &aiter::gelu_and_mul,                                                \
          "Activation function used in GELU.",                                 \
          py::arg("out"),                                                      \
          py::arg("input"));                                                   \
    m.def("gelu_tanh_and_mul",                                                 \
          &aiter::gelu_tanh_and_mul,                                           \
          "Activation function used in GELU tanh.",                            \
          py::arg("out"),                                                      \
          py::arg("input"));

// AITER_OPERATOR_PYBIND: registers the binary element-wise operator bindings
// add/mul/sub/div and their trailing-underscore variants add_/mul_/sub_/div_
// on module `m`. No py::arg names are declared for any of them.
#define AITER_OPERATOR_PYBIND                                                  \
    m.def("add", &aiter_add,                                                   \
          "apply for add with transpose and broadcast.");                      \
    m.def("mul", &aiter_mul,                                                   \
          "apply for mul with transpose and broadcast.");                      \
    m.def("sub", &aiter_sub,                                                   \
          "apply for sub with transpose and broadcast.");                      \
    m.def("div", &aiter_div,                                                   \
          "apply for div with transpose and broadcast.");                      \
    m.def("add_", &aiter_add_,                                                 \
          "apply for add_ with transpose and broadcast.");                     \
    m.def("mul_", &aiter_mul_,                                                 \
          "apply for mul_ with transpose and broadcast.");                     \
    m.def("sub_", &aiter_sub_,                                                 \
          "apply for sub_ with transpose and broadcast.");                     \
    m.def("div_", &aiter_div_,                                                 \
          "apply for div_ with transpose and broadcast.");
// AITER_UNARY_PYBIND: registers the unary element-wise bindings sigmoid and
// tanh on module `m`; no py::arg names are declared.
#define AITER_UNARY_PYBIND                                                     \
    m.def("sigmoid",                                                           \
          &aiter_sigmoid,                                                      \
          "apply for sigmoid.");                                               \
    m.def("tanh",                                                              \
          &aiter_tanh,                                                         \
          "apply for tanh.");

// ATTENTION_ASM_MLA_PYBIND: registers the MLA asm entry points
// mla_decode_stage1_asm_fwd and mla_prefill_asm_fwd on module `m`.
// Both declare the same ten keyword arguments, all required (no defaults).
#define ATTENTION_ASM_MLA_PYBIND                                               \
    m.def("mla_decode_stage1_asm_fwd",                                         \
          &mla_decode_stage1_asm_fwd,                                          \
          "mla_decode_stage1_asm_fwd",                                         \
          py::arg("Q"), py::arg("KV"),                                         \
          py::arg("qo_indptr"), py::arg("kv_indptr"),                          \
          py::arg("kv_page_indices"), py::arg("kv_last_page_lens"),            \
          py::arg("max_seqlen_q"), py::arg("softmax_scale"),                   \
          py::arg("splitData"), py::arg("splitLse"));                          \
    m.def("mla_prefill_asm_fwd",                                               \
          &mla_prefill_asm_fwd,                                                \
          "mla_prefill_asm_fwd",                                               \
          py::arg("Q"), py::arg("KV"),                                         \
          py::arg("qo_indptr"), py::arg("kv_indptr"),                          \
          py::arg("kv_page_indices"), py::arg("kv_last_page_lens"),            \
          py::arg("max_seqlen_q"), py::arg("softmax_scale"),                   \
          py::arg("splitData"), py::arg("splitLse"));

// ATTENTION_ASM_PYBIND: exposes pa_fwd under the Python name `pa_fwd_asm` on
// module `m`. K_QScale, V_QScale and out_ default to std::nullopt;
// high_precision defaults to 1. The remaining arguments are required.
#define ATTENTION_ASM_PYBIND                                                   \
    m.def("pa_fwd_asm",                                                        \
          &pa_fwd,                                                             \
          "pa_fwd",                                                            \
          py::arg("Q"), py::arg("K"), py::arg("V"),                            \
          py::arg("block_tables"),                                             \
          py::arg("context_lens"),                                             \
          py::arg("max_num_blocks"),                                           \
          py::arg("K_QScale") = std::nullopt,                                  \
          py::arg("V_QScale") = std::nullopt,                                  \
          py::arg("out_") = std::nullopt,                                      \
          py::arg("high_precision") = 1);

// ATTENTION_CK_PYBIND: exposes pa_fwd_naive on module `m`. Every argument is
// required except the trailing `out_`, which defaults to std::nullopt.
#define ATTENTION_CK_PYBIND                                                    \
    m.def("pa_fwd_naive",                                                      \
          &pa_fwd_naive,                                                       \
          "pa_fwd_naive",                                                      \
          py::arg("Q"), py::arg("K"), py::arg("V"),                            \
          py::arg("block_tables"), py::arg("context_lens"),                    \
          py::arg("k_dequant_scales"), py::arg("v_dequant_scales"),            \
          py::arg("max_seq_len"), py::arg("num_kv_heads"),                     \
          py::arg("scale_s"), py::arg("scale_k"), py::arg("scale_v"),          \
          py::arg("block_size"), py::arg("quant_algo"),                        \
          py::arg("out_") = std::nullopt);

// ATTENTION_PYBIND: exposes paged_attention under the Python name
// `paged_attention_rocm` on module `m`. The string literal is only the
// Python docstring, written as a schema-like signature. No py::arg names are
// declared.
#define ATTENTION_PYBIND                                                       \
    m.def(                                                                     \
        "paged_attention_rocm",                                                \
        &paged_attention,                                                      \
        "paged_attention_rocm(Tensor! out, Tensor exp_sums,"                   \
        "                Tensor max_logits, Tensor tmp_out,"                   \
        "                Tensor query, Tensor key_cache,"                      \
        "                Tensor value_cache, int num_kv_heads,"                \
        "                float scale, Tensor block_tables,"                    \
        "                Tensor context_lens, int block_size,"                 \
        "                int max_context_len,"                                 \
        "                Tensor? alibi_slopes,"                                \
        "                str kv_cache_dtype,"                                  \
        "                float k_scale, float v_scale) -> ()");

// ATTENTION_RAGGED_PYBIND: exposes paged_attention_ragged on module `m`.
// The string literal is only the Python docstring (a schema-like signature).
// No py::arg names are declared.
#define ATTENTION_RAGGED_PYBIND                                                \
    m.def(                                                                     \
        "paged_attention_ragged",                                              \
        &paged_attention_ragged,                                               \
        "paged_attention_ragged(Tensor! out, Tensor exp_sums,"                 \
        "                Tensor max_logits, Tensor tmp_out,"                   \
        "                Tensor query, Tensor key_cache,"                      \
        "                Tensor value_cache, int num_kv_heads,"                \
        "                float scale, Tensor block_tables,"                    \
        "                Tensor context_lens, int block_size,"                 \
        "                int max_context_len,"                                 \
        "                Tensor? alibi_slopes,"                                \
        "                str kv_cache_dtype,"                                  \
        "                float k_scale, float v_scale) -> ()");

// BATCHED_GEMM_A8W8_PYBIND: exposes batched_gemm_a8w8 on module `m`.
// `bias` defaults to std::nullopt and `splitK` to 0.
#define BATCHED_GEMM_A8W8_PYBIND                                               \
    m.def("batched_gemm_a8w8",                                                 \
          &batched_gemm_a8w8,                                                  \
          "batched_gemm_a8w8",                                                 \
          py::arg("XQ"),                                                       \
          py::arg("WQ"),                                                       \
          py::arg("x_scale"),                                                  \
          py::arg("w_scale"),                                                  \
          py::arg("Out"),                                                      \
          py::arg("bias") = std::nullopt,                                      \
          py::arg("splitK") = 0);

// BATCHED_GEMM_A8W8_TUNE_PYBIND: exposes batched_gemm_a8w8_tune on module
// `m`. `kernelId` and `splitK` both default to 0.
#define BATCHED_GEMM_A8W8_TUNE_PYBIND                                          \
    m.def("batched_gemm_a8w8_tune",                                            \
          &batched_gemm_a8w8_tune,                                             \
          "batched_gemm_a8w8_tune",                                            \
          py::arg("XQ"),                                                       \
          py::arg("WQ"),                                                       \
          py::arg("x_scale"),                                                  \
          py::arg("w_scale"),                                                  \
          py::arg("Out"),                                                      \
          py::arg("kernelId") = 0,                                             \
          py::arg("splitK") = 0);

// CACHE_PYBIND: registers the KV-cache bindings on module `m`: swap_blocks,
// copy_blocks, reshape_and_cache, reshape_and_cache_flash,
// reshape_and_cache_with_pertoken_quant, reshape_and_cache_with_block_quant
// and convert_fp8. None of them declare py::arg names.
// The string passed to each m.def is only its Python docstring, written in
// PyTorch operator-schema notation. Scalars therefore use bare schema types
// (`bool`, `float`, `str`) rather than C++ spellings such as `const bool`.
// NOTE(review): copy_blocks annotates key_caches per element (`Tensor(a!)[]`)
// but value_caches on the list itself (`Tensor[](b!)`); confirm which form is
// intended.
// NOTE(review): reshape_and_cache's docstring is just its name.
#define CACHE_PYBIND                                                           \
    m.def("swap_blocks", &swap_blocks,                                         \
          "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); \
    m.def("copy_blocks", &copy_blocks,                                         \
          "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "   \
          "Tensor block_mapping) -> ()");                                      \
    m.def("reshape_and_cache", &reshape_and_cache,                             \
          "reshape_and_cache");                                                \
    m.def("reshape_and_cache_flash", &reshape_and_cache_flash,                 \
          "reshape_and_cache_flash(Tensor key, Tensor value,"                  \
          "                        Tensor! key_cache,"                         \
          "                        Tensor! value_cache,"                       \
          "                        Tensor slot_mapping,"                       \
          "                        str kv_cache_dtype,"                        \
          "                        float k_scale, float v_scale) -> ()");      \
    m.def("reshape_and_cache_with_pertoken_quant",                             \
          &reshape_and_cache_with_pertoken_quant,                              \
          "reshape_and_cache_with_pertoken_quant(Tensor key, Tensor value,"    \
          "                        Tensor! key_cache,"                         \
          "                        Tensor! value_cache,"                       \
          "                        Tensor! k_dequant_scales,"                  \
          "                        Tensor! v_dequant_scales,"                  \
          "                        Tensor slot_mapping) -> ()");               \
    m.def("reshape_and_cache_with_block_quant",                                \
          &reshape_and_cache_with_block_quant,                                 \
          "reshape_and_cache_with_block_quant(Tensor key, Tensor value,"       \
          "                        Tensor! key_cache,"                         \
          "                        Tensor! value_cache,"                       \
          "                        Tensor! k_dequant_scales,"                  \
          "                        Tensor! v_dequant_scales,"                  \
          "                        Tensor slot_mapping,"                       \
          "                        bool asm_layout) -> ()");                   \
    m.def("convert_fp8", &convert_fp8,                                         \
          "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "     \
          "str kv_cache_dtype) -> ()");

// CUSTOM_ALL_REDUCE_PYBIND: registers the aiter custom all-reduce/all-gather
// bindings on module `m`.
// init_custom_ar is documented as returning an int. Most other entries take
// an `_fa` keyword, presumably that returned value (confirm against callers).
// all_reduce and fused_allreduce_rmsnorm default `reg_buffer` to
// std::nullopt; every other declared argument is required.
// NOTE(review): several docstrings name the first parameter `fa` while the
// keyword declared via py::arg is `_fa`.
#define CUSTOM_ALL_REDUCE_PYBIND                                               \
    m.def("init_custom_ar", &aiter::init_custom_ar,                            \
          "init_custom_ar(Tensor meta, Tensor rank_data, "                     \
          "str[] handles, int[] offsets, int rank, "                           \
          "bool fully_connected) -> int",                                      \
          py::arg("meta"), py::arg("rank_data"),                               \
          py::arg("handles"), py::arg("offsets"),                              \
          py::arg("rank"), py::arg("fully_connected"));                        \
    m.def("all_reduce", &aiter::all_reduce,                                    \
          py::arg("_fa"), py::arg("inp"), py::arg("out"),                      \
          py::arg("open_fp8_quant"),                                           \
          py::arg("reg_buffer") = std::nullopt);                               \
    m.def("all_gather_reg", &aiter::all_gather_reg,                            \
          "all_gather_reg(int fa, Tensor inp, Tensor! out) -> ()",             \
          py::arg("_fa"), py::arg("inp"), py::arg("out"));                     \
    m.def("all_gather_unreg", &aiter::all_gather_unreg,                        \
          "all_gather_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> ()", \
          py::arg("_fa"), py::arg("inp"),                                      \
          py::arg("reg_buffer"), py::arg("out"));                              \
    m.def("fused_allreduce_rmsnorm", &aiter::fused_allreduce_rmsnorm,          \
          py::arg("_fa"), py::arg("inp"),                                      \
          py::arg("res_inp"), py::arg("res_out"),                              \
          py::arg("out"), py::arg("w"), py::arg("eps"),                        \
          py::arg("reg_buffer") = std::nullopt);                               \
    m.def("dispose", &aiter::dispose,                                          \
          py::arg("_fa"));                                                     \
    m.def("meta_size", &aiter::meta_size);                                     \
    m.def("register_buffer", &aiter::register_buffer,                          \
          "register_buffer(int fa, Tensor t, str[] handles, int[] offsets) -> ()", \
          py::arg("_fa"), py::arg("t"),                                        \
          py::arg("handles"), py::arg("offsets"));                             \
    m.def("get_graph_buffer_ipc_meta", &aiter::get_graph_buffer_ipc_meta,      \
          py::arg("_fa"));                                                     \
    m.def("register_graph_buffers", &aiter::register_graph_buffers,            \
          py::arg("_fa"), py::arg("handles"), py::arg("offsets"));             \
    m.def("allocate_meta_buffer", &aiter::allocate_meta_buffer,                \
          py::arg("size"));                                                    \
    m.def("get_meta_buffer_ipc_handle", &aiter::get_meta_buffer_ipc_handle,    \
          py::arg("inp"));



// CUSTOM_PYBIND: registers wvSpltK and LLMM1 on module `m`. Both carry
// schema-style docstrings and declare no py::arg names.
#define CUSTOM_PYBIND                                                          \
    m.def("wvSpltK",                                                           \
          &wvSpltK,                                                            \
          "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"         \
          "        int CuCount) -> ()");                                       \
    m.def("LLMM1",                                                             \
          &LLMM1,                                                              \
          "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " \
          "()");

// GEMM_A8W8_ASM_PYBIND: exposes gemm_a8w8_asm on module `m`.
// `bias` has no default. sub_m and sub_n default to 128; pad_a, pad_b,
// pad_c and splitK default to 0.
#define GEMM_A8W8_ASM_PYBIND                                                   \
    m.def("gemm_a8w8_asm",                                                     \
          &gemm_a8w8_asm,                                                      \
          "Asm gemm a8w8 ,  weight should be shuffle to layout(32,16)",        \
          py::arg("XQ"),                                                       \
          py::arg("WQ"),                                                       \
          py::arg("x_scale"),                                                  \
          py::arg("w_scale"),                                                  \
          py::arg("Out"),                                                      \
          py::arg("bias"),                                                     \
          py::arg("sub_m") = 128,                                              \
          py::arg("sub_n") = 128,                                              \
          py::arg("pad_a") = 0,                                                \
          py::arg("pad_b") = 0,                                                \
          py::arg("pad_c") = 0,                                                \
          py::arg("splitK") = 0);

// GEMM_A8W8_BLOCKSCALE_PYBIND: exposes gemm_a8w8_blockscale on module `m`.
// All five keyword arguments are required.
#define GEMM_A8W8_BLOCKSCALE_PYBIND                                            \
    m.def("gemm_a8w8_blockscale",                                              \
          &gemm_a8w8_blockscale,                                               \
          "fp8 blockscale gemm",                                               \
          py::arg("XQ"), py::arg("WQ"),                                        \
          py::arg("x_scale"), py::arg("w_scale"),                              \
          py::arg("Out"));

// GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND: exposes gemm_a8w8_blockscale_tune on
// module `m`. `kernelId` and `splitK` both default to 0.
#define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND                                       \
    m.def("gemm_a8w8_blockscale_tune",                                         \
          &gemm_a8w8_blockscale_tune,                                          \
          "gemm_a8w8_blockscale_tune",                                         \
          py::arg("XQ"), py::arg("WQ"),                                        \
          py::arg("x_scale"), py::arg("w_scale"),                              \
          py::arg("Out"),                                                      \
          py::arg("kernelId") = 0,                                             \
          py::arg("splitK") = 0);

// GEMM_A8W8_PYBIND: exposes gemm_a8w8 on module `m`.
// `bias` defaults to std::nullopt and `splitK` to 0.
#define GEMM_A8W8_PYBIND                                                       \
    m.def("gemm_a8w8",                                                         \
          &gemm_a8w8,                                                          \
          "gemm_a8w8",                                                         \
          py::arg("XQ"), py::arg("WQ"),                                        \
          py::arg("x_scale"), py::arg("w_scale"),                              \
          py::arg("Out"),                                                      \
          py::arg("bias") = std::nullopt,                                      \
          py::arg("splitK") = 0);

// GEMM_A8W8_TUNE_PYBIND: exposes gemm_a8w8_tune on module `m`.
// `kernelId` and `splitK` both default to 0.
#define GEMM_A8W8_TUNE_PYBIND                                                  \
    m.def("gemm_a8w8_tune",                                                    \
          &gemm_a8w8_tune,                                                     \
          "gemm_a8w8_tune",                                                    \
          py::arg("XQ"), py::arg("WQ"),                                        \
          py::arg("x_scale"), py::arg("w_scale"),                              \
          py::arg("Out"),                                                      \
          py::arg("kernelId") = 0,                                             \
          py::arg("splitK") = 0);

// MHA_BWD_ASM_PYBIND: exposes aiter::torch_itfs::fmha_v3_bwd as `fmha_v3_bwd`
// on module `m`. Arguments up to and including how_v3_bf16_cvt are required;
// dq, dk, dv, alibi_slopes, rng_state and gen default to std::nullopt.
#define MHA_BWD_ASM_PYBIND                                                     \
    m.def("fmha_v3_bwd",                                                       \
          &aiter::torch_itfs::fmha_v3_bwd,                                     \
          py::arg("dout"),                                                     \
          py::arg("q"), py::arg("k"), py::arg("v"),                            \
          py::arg("out"), py::arg("softmax_lse"),                              \
          py::arg("dropout_p"), py::arg("softmax_scale"),                      \
          py::arg("is_causal"),                                                \
          py::arg("window_size_left"), py::arg("window_size_right"),           \
          py::arg("deterministic"),                                            \
          py::arg("is_v3_atomic_fp32"), py::arg("how_v3_bf16_cvt"),            \
          py::arg("dq") = std::nullopt,                                        \
          py::arg("dk") = std::nullopt,                                        \
          py::arg("dv") = std::nullopt,                                        \
          py::arg("alibi_slopes") = std::nullopt,                              \
          py::arg("rng_state") = std::nullopt,                                 \
          py::arg("gen") = std::nullopt);

// MHA_VARLEN_BWD_ASM_PYBIND: exposes aiter::torch_itfs::fmha_v3_varlen_bwd as
// `fmha_v3_varlen_bwd` on module `m`. Same shape as fmha_v3_bwd plus the
// cu_seqlens_q/k, max_seqlen_q/k and zero_tensors arguments. dq, dk, dv,
// alibi_slopes, rng_state and gen default to std::nullopt.
#define MHA_VARLEN_BWD_ASM_PYBIND                                              \
    m.def("fmha_v3_varlen_bwd",                                                \
          &aiter::torch_itfs::fmha_v3_varlen_bwd,                              \
          py::arg("dout"),                                                     \
          py::arg("q"), py::arg("k"), py::arg("v"),                            \
          py::arg("out"), py::arg("softmax_lse"),                              \
          py::arg("cu_seqlens_q"), py::arg("cu_seqlens_k"),                    \
          py::arg("max_seqlen_q"), py::arg("max_seqlen_k"),                    \
          py::arg("dropout_p"), py::arg("softmax_scale"),                      \
          py::arg("zero_tensors"),                                             \
          py::arg("is_causal"),                                                \
          py::arg("window_size_left"), py::arg("window_size_right"),           \
          py::arg("deterministic"),                                            \
          py::arg("is_v3_atomic_fp32"), py::arg("how_v3_bf16_cvt"),            \
          py::arg("dq") = std::nullopt,                                        \
          py::arg("dk") = std::nullopt,                                        \
          py::arg("dv") = std::nullopt,                                        \
          py::arg("alibi_slopes") = std::nullopt,                              \
          py::arg("rng_state") = std::nullopt,                                 \
          py::arg("gen") = std::nullopt);

// MHA_BWD_PYBIND: exposes aiter::torch_itfs::mha_bwd as `mha_bwd` on module
// `m`. Arguments up to and including `deterministic` are required; dq, dk,
// dv, dbias, bias, alibi_slopes, rng_state and gen default to std::nullopt.
#define MHA_BWD_PYBIND                                                         \
    m.def("mha_bwd",                                                           \
          &aiter::torch_itfs::mha_bwd,                                         \
          py::arg("dout"),                                                     \
          py::arg("q"), py::arg("k"), py::arg("v"),                            \
          py::arg("out"), py::arg("softmax_lse"),                              \
          py::arg("dropout_p"), py::arg("softmax_scale"),                      \
          py::arg("is_causal"),                                                \
          py::arg("window_size_left"), py::arg("window_size_right"),           \
          py::arg("deterministic"),                                            \
          py::arg("dq") = std::nullopt,                                        \
          py::arg("dk") = std::nullopt,                                        \
          py::arg("dv") = std::nullopt,                                        \
          py::arg("dbias") = std::nullopt,                                     \
          py::arg("bias") = std::nullopt,                                      \
          py::arg("alibi_slopes") = std::nullopt,                              \
          py::arg("rng_state") = std::nullopt,                                 \
          py::arg("gen") = std::nullopt);

// MHA_FWD_ASM_PYBIND: exposes aiter::torch_itfs::fmha_v3_fwd as `fmha_v3_fwd`
// on module `m`. Arguments up to and including return_dropout_randval are
// required; out, bias, alibi_slopes and gen default to std::nullopt.
#define MHA_FWD_ASM_PYBIND                                                     \
    m.def("fmha_v3_fwd",                                                       \
          &aiter::torch_itfs::fmha_v3_fwd,                                     \
          py::arg("q"), py::arg("k"), py::arg("v"),                            \
          py::arg("dropout_p"), py::arg("softmax_scale"),                      \
          py::arg("is_causal"),                                                \
          py::arg("window_size_left"), py::arg("window_size_right"),           \
          py::arg("return_softmax_lse"),                                       \
          py::arg("return_dropout_randval"),                                   \
          py::arg("out") = std::nullopt,                                       \
          py::arg("bias") = std::nullopt,                                      \
          py::arg("alibi_slopes") = std::nullopt,                              \
          py::arg("gen") = std::nullopt);

// MHA_FWD_PYBIND: exposes aiter::torch_itfs::mha_fwd as `mha_fwd` on module
// `m`. Arguments up to and including return_dropout_randval are required;
// out, bias, alibi_slopes and gen default to std::nullopt.
#define MHA_FWD_PYBIND                                                         \
    m.def("mha_fwd",                                                           \
          &aiter::torch_itfs::mha_fwd,                                         \
          py::arg("q"), py::arg("k"), py::arg("v"),                            \
          py::arg("dropout_p"), py::arg("softmax_scale"),                      \
          py::arg("is_causal"),                                                \
          py::arg("window_size_left"), py::arg("window_size_right"),           \
          py::arg("return_softmax_lse"),                                       \
          py::arg("return_dropout_randval"),                                   \
          py::arg("out") = std::nullopt,                                       \
          py::arg("bias") = std::nullopt,                                      \
          py::arg("alibi_slopes") = std::nullopt,                              \
          py::arg("gen") = std::nullopt);

// MHA_VARLEN_BWD_PYBIND: exposes aiter::torch_itfs::mha_varlen_bwd to Python
// as "mha_varlen_bwd". Expands to m.def(...); a pybind11 module object `m`
// and the `py` namespace alias must be in scope where it is used.
//   required : dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
//              max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
//              zero_tensors, is_causal, window_size_left, window_size_right,
//              deterministic
//   optional : dq, dk, dv, alibi_slopes, rng_state, gen (default std::nullopt)
#define MHA_VARLEN_BWD_PYBIND                                                 \
      m.def("mha_varlen_bwd",                                                 \
            &aiter::torch_itfs::mha_varlen_bwd,                               \
            py::arg("dout"),                                                  \
            py::arg("q"), py::arg("k"), py::arg("v"),                         \
            py::arg("out"), py::arg("softmax_lse"),                           \
            py::arg("cu_seqlens_q"), py::arg("cu_seqlens_k"),                 \
            py::arg("max_seqlen_q"), py::arg("max_seqlen_k"),                 \
            py::arg("dropout_p"), py::arg("softmax_scale"),                   \
            py::arg("zero_tensors"), py::arg("is_causal"),                    \
            py::arg("window_size_left"), py::arg("window_size_right"),        \
            py::arg("deterministic"),                                         \
            py::arg("dq")           = std::nullopt,                           \
            py::arg("dk")           = std::nullopt,                           \
            py::arg("dv")           = std::nullopt,                           \
            py::arg("alibi_slopes") = std::nullopt,                           \
            py::arg("rng_state")    = std::nullopt,                           \
            py::arg("gen")          = std::nullopt);

// MHA_VARLEN_FWD_PYBIND: exposes aiter::torch_itfs::mha_varlen_fwd to Python
// as "mha_varlen_fwd". Expands to m.def(...); a pybind11 module object `m`
// and the `py` namespace alias must be in scope where it is used.
//   required : q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
//              max_seqlen_k, dropout_p, softmax_scale, logits_soft_cap,
//              zero_tensors, is_causal, window_size_left, window_size_right,
//              return_softmax_lse, return_dropout_randval
//   optional : out, block_table, bias, alibi_slopes, gen (default std::nullopt)
#define MHA_VARLEN_FWD_PYBIND                                                 \
      m.def("mha_varlen_fwd",                                                 \
            &aiter::torch_itfs::mha_varlen_fwd,                               \
            py::arg("q"), py::arg("k"), py::arg("v"),                         \
            py::arg("cu_seqlens_q"), py::arg("cu_seqlens_k"),                 \
            py::arg("max_seqlen_q"), py::arg("max_seqlen_k"),                 \
            py::arg("dropout_p"), py::arg("softmax_scale"),                   \
            py::arg("logits_soft_cap"),                                       \
            py::arg("zero_tensors"), py::arg("is_causal"),                    \
            py::arg("window_size_left"), py::arg("window_size_right"),        \
            py::arg("return_softmax_lse"), py::arg("return_dropout_randval"), \
            py::arg("out")          = std::nullopt,                           \
            py::arg("block_table")  = std::nullopt,                           \
            py::arg("bias")         = std::nullopt,                           \
            py::arg("alibi_slopes") = std::nullopt,                           \
            py::arg("gen")          = std::nullopt);

// MHA_BATCH_PREFILL_PYBIND: exposes aiter::torch_itfs::mha_batch_prefill to
// Python as "mha_batch_prefill". Expands to m.def(...); a pybind11 module
// object `m` and the `py` namespace alias must be in scope where it is used.
//   required : q, k, v, cu_seqlens_q, kv_indptr, kv_page_indices,
//              max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
//              logits_soft_cap, zero_tensors, is_causal, window_size_left,
//              window_size_right, return_softmax_lse, return_dropout_randval
//   optional : out, bias, alibi_slopes, gen   (default std::nullopt)
#define MHA_BATCH_PREFILL_PYBIND                                              \
      m.def("mha_batch_prefill",                                              \
            &aiter::torch_itfs::mha_batch_prefill,                            \
            py::arg("q"), py::arg("k"), py::arg("v"),                         \
            py::arg("cu_seqlens_q"),                                          \
            py::arg("kv_indptr"), py::arg("kv_page_indices"),                 \
            py::arg("max_seqlen_q"), py::arg("max_seqlen_k"),                 \
            py::arg("dropout_p"), py::arg("softmax_scale"),                   \
            py::arg("logits_soft_cap"),                                       \
            py::arg("zero_tensors"), py::arg("is_causal"),                    \
            py::arg("window_size_left"), py::arg("window_size_right"),        \
            py::arg("return_softmax_lse"), py::arg("return_dropout_randval"), \
            py::arg("out")          = std::nullopt,                           \
            py::arg("bias")         = std::nullopt,                           \
            py::arg("alibi_slopes") = std::nullopt,                           \
            py::arg("gen")          = std::nullopt);

// MOE_CK_2STAGES_PYBIND: binds ck_moe_stage1 and ck_moe_stage2 on the
// pybind11 module object `m` (and uses the `py` namespace alias), both of
// which must be in scope where the macro is expanded.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_CK_2STAGES_PYBIND                          \
      m.def("ck_moe_stage1", &ck_moe_stage1,           \
            py::arg("hidden_states"),                  \
            py::arg("w1"),                             \
            py::arg("w2"),                             \
            py::arg("sorted_token_ids"),               \
            py::arg("sorted_expert_ids"),              \
            py::arg("num_valid_ids"),                  \
            py::arg("out"),                            \
            py::arg("topk"),                           \
            py::arg("w1_scale") = std::nullopt,        \
            py::arg("a1_scale") = std::nullopt,        \
            py::arg("block_m") = 32,                   \
            py::arg("sorted_weights") = std::nullopt,  \
            py::arg("act_op") = 0);                    \
                                                       \
      m.def("ck_moe_stage2", &ck_moe_stage2,           \
            py::arg("inter_states"),                   \
            py::arg("w1"),                             \
            py::arg("w2"),                             \
            py::arg("sorted_token_ids"),               \
            py::arg("sorted_expert_ids"),              \
            py::arg("num_valid_ids"),                  \
            py::arg("out"),                            \
            py::arg("topk"),                           \
            py::arg("w2_scale") = std::nullopt,        \
            py::arg("a2_scale") = std::nullopt,        \
            py::arg("block_m") = 32,                   \
            py::arg("sorted_weights") = std::nullopt);

// MOE_ASM_2STAGES_PYBIND: binds asm_fmoe_stage1, asm_fmoe_stage2, asm_fmoe_a8
// and asm_moe_get_solutions on the pybind11 module object `m` (and uses the
// `py` namespace alias), both of which must be in scope at the expansion site.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_ASM_2STAGES_PYBIND                                 \
      m.def("asm_fmoe_stage1", &asm_fmoe_stage1,               \
            py::arg("out"),                                    \
            py::arg("input"),                                  \
            py::arg("gate"),                                   \
            py::arg("down"),                                   \
            py::arg("sorted_token_ids"),                       \
            py::arg("sorted_weights"),                         \
            py::arg("sorted_expert_ids"),                      \
            py::arg("num_valid_ids"),                          \
            py::arg("top_k"),                                  \
            py::arg("scale_a") = std::nullopt,                 \
            py::arg("scale_b") = std::nullopt,                 \
            py::arg("zero_points") = std::nullopt,             \
            py::arg("mode") = 0,                               \
            py::arg("solidx") = 0,                             \
            py::arg("block_size") = 16,                        \
            py::arg("persist_groups") = 0);                    \
                                                               \
      m.def("asm_fmoe_stage2", &asm_fmoe_stage2,               \
            py::arg("out"),                                    \
            py::arg("input"),                                  \
            py::arg("gate"),                                   \
            py::arg("down"),                                   \
            py::arg("sorted_token_ids"),                       \
            py::arg("sorted_weights"),                         \
            py::arg("sorted_expert_ids"),                      \
            py::arg("num_valid_ids"),                          \
            py::arg("top_k"),                                  \
            py::arg("scale_a") = std::nullopt,                 \
            py::arg("scale_b") = std::nullopt,                 \
            py::arg("zero_points") = std::nullopt,             \
            py::arg("mode") = 0,                               \
            py::arg("solidx") = 0,                             \
            py::arg("block_size") = 16,                        \
            py::arg("persist_groups") = 0);                    \
                                                               \
      m.def("asm_fmoe_a8", &asm_fmoe_a8,                       \
            py::arg("out"),                                    \
            py::arg("input"),                                  \
            py::arg("gate"),                                   \
            py::arg("down"),                                   \
            py::arg("sorted_token_ids"),                       \
            py::arg("sorted_weights"),                         \
            py::arg("sorted_expert_ids"),                      \
            py::arg("num_valid_ids"),                          \
            py::arg("top_k"),                                  \
            py::arg("scale_a") = std::nullopt,                 \
            py::arg("scale_b") = std::nullopt,                 \
            py::arg("zero_points") = std::nullopt,             \
            py::arg("mode") = 0,                               \
            py::arg("solidx") = 0,                             \
            py::arg("out_type") = 0,                           \
            py::arg("persist_groups") = 0,                     \
            py::arg("use_shuffle") = 0);                       \
                                                               \
      m.def("asm_moe_get_solutions", &asm_moe_get_solutions,   \
            py::arg("hidden_states"),                          \
            py::arg("w1"),                                     \
            py::arg("w2"),                                     \
            py::arg("topk_weights"),                           \
            py::arg("topk_ids"),                               \
            py::arg("use_int8_w8a16") = false,                 \
            py::arg("use_int4_w4a16") = false,                 \
            py::arg("use_int8_w8a8") = false,                  \
            py::arg("use_int4_w4a8") = false,                  \
            py::arg("use_fp8_w8a8") = false,                   \
            py::arg("per_channel_quant") = false,              \
            py::arg("w1_zp") = std::nullopt,                   \
            py::arg("w2_zp") = std::nullopt,                   \
            py::arg("w1_scale") = std::nullopt,                \
            py::arg("w2_scale") = std::nullopt,                \
            py::arg("a1_scale") = std::nullopt,                \
            py::arg("a2_scale") = std::nullopt,                \
            py::arg("block_shape_n") = 0,                      \
            py::arg("block_shape_k") = 0,                      \
            py::arg("block_m") = 32,                           \
            py::arg("expert_mask") = std::nullopt);

// AWQ_GEMM_ASM_PYBIND: binds awq_gemm_asm and awq_gemm_asm_tuning on the
// pybind11 module object `m` (and uses the `py` namespace alias), both of
// which must be in scope where the macro is expanded.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define AWQ_GEMM_ASM_PYBIND                                      \
      m.def("awq_gemm_asm", &awq_gemm_asm,                       \
            py::arg("out"),                                      \
            py::arg("mat1"),                                     \
            py::arg("mat2"),                                     \
            py::arg("zero") = std::nullopt,                      \
            py::arg("scalar") = std::nullopt);                   \
      m.def("awq_gemm_asm_tuning", &awq_gemm_asm_tuning,         \
            py::arg("out"),                                      \
            py::arg("mat1"),                                     \
            py::arg("mat2"),                                     \
            py::arg("zero") = std::nullopt,                      \
            py::arg("scalar") = std::nullopt,                    \
            py::arg("solidx") = 0,                               \
            py::arg("jsonfile") = std::nullopt);

// AWQ_DQ_ASM_PYBIND: binds awq_dq_asm on the pybind11 module object `m`
// (and uses the `py` namespace alias); both must be in scope at expansion.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define AWQ_DQ_ASM_PYBIND                           \
      m.def("awq_dq_asm", &awq_dq_asm,              \
            py::arg("out"),                         \
            py::arg("mat1"),                        \
            py::arg("zero") = std::nullopt,         \
            py::arg("scalar") = std::nullopt);


// MOE_CK_PYBIND: binds the CK MoE entry points ck_moe, ck_shuffle_moe,
// ck_moe_get_solutions, ck_moe_stage_1, ck_moe_stage_2 and
// ck_moe_per_token_quant on the pybind11 module object `m` (and uses the
// `py` namespace alias); both must be in scope where the macro is expanded.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_CK_PYBIND                                                               \
      m.def("ck_moe", &ck_moe,                                                      \
            py::arg("hidden_states"), py::arg("w1"), py::arg("w2"),                 \
            py::arg("topk_weights"), py::arg("topk_ids"),                           \
            py::arg("use_int8_w8a16") = false,                                      \
            py::arg("use_int4_w4a16") = false,                                      \
            py::arg("use_int8_w8a8_block") = false,                                 \
            py::arg("use_int4_w4a8_block") = false,                                 \
            py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt,       \
            py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \
            py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \
            py::arg("block_shape_n") = 0,                                           \
            py::arg("block_shape_k") = 0,                                           \
            py::arg("block_m") = 32,                                                \
            py::arg("solution_id") = 0,                                             \
            py::arg("expert_mask") = std::nullopt);                                 \
      m.def("ck_shuffle_moe", &ck_shuffle_moe,                                      \
            py::arg("hidden_states"), py::arg("w1"), py::arg("w2"),                 \
            py::arg("topk_weights"), py::arg("topk_ids"),                           \
            py::arg("use_int8_w8a16") = false,                                      \
            py::arg("use_int4_w4a16") = false,                                      \
            py::arg("use_int8_w8a8_block") = false,                                 \
            py::arg("use_int4_w4a8_block") = false,                                 \
            py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt,       \
            py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \
            py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \
            py::arg("block_shape_n") = 0,                                           \
            py::arg("block_shape_k") = 0,                                           \
            py::arg("block_m") = 32,                                                \
            py::arg("solution_id") = 0,                                             \
            py::arg("expert_mask") = std::nullopt);                                 \
      m.def("ck_moe_get_solutions", &ck_moe_get_solutions,                          \
            py::arg("hidden_states"), py::arg("w1"), py::arg("w2"),                 \
            py::arg("topk_weights"), py::arg("topk_ids"),                           \
            py::arg("use_int8_w8a16") = false,                                      \
            py::arg("use_int4_w4a16") = false,                                      \
            py::arg("use_int8_w8a8_block") = false,                                 \
            py::arg("use_int4_w4a8_block") = false,                                 \
            py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt,       \
            py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \
            py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \
            py::arg("block_shape_n") = 0,                                           \
            py::arg("block_shape_k") = 0,                                           \
            py::arg("block_m") = 32,                                                \
            py::arg("expert_mask") = std::nullopt);                                 \
      m.def("ck_moe_stage_1", &ck_moe_stage_1,                                      \
            py::arg("hidden_states"),                                               \
            py::arg("w1"),                                                          \
            py::arg("w2"),                                                          \
            py::arg("sorted_token_ids"),                                            \
            py::arg("sorted_expert_ids"),                                           \
            py::arg("tokens_positions_per_expert"),                                 \
            py::arg("num_valid_ids"),                                               \
            py::arg("out"),                                                         \
            py::arg("topk"),                                                        \
            py::arg("use_int8_w8a8_block") = false,                                 \
            py::arg("use_fp8_w8a8_block") = false,                                  \
            py::arg("w1_scale") = std::nullopt,                                     \
            py::arg("a1_scale") = std::nullopt,                                     \
            py::arg("block_shape_n") = 0,                                           \
            py::arg("block_shape_k") = 0,                                           \
            py::arg("block_m") = 32,                                                \
            py::arg("sorted_weights") = std::nullopt,                               \
            py::arg("act_op") = 0);                                                 \
      m.def("ck_moe_stage_2", &ck_moe_stage_2,                                      \
            py::arg("inter_states"),                                                \
            py::arg("w1"),                                                          \
            py::arg("w2"),                                                          \
            py::arg("sorted_token_ids"),                                            \
            py::arg("sorted_expert_ids"),                                           \
            py::arg("tokens_positions_per_expert"),                                 \
            py::arg("num_valid_ids"),                                               \
            py::arg("out"),                                                         \
            py::arg("topk"),                                                        \
            py::arg("use_int8_w8a8_block") = false,                                 \
            py::arg("use_fp8_w8a8_block") = false,                                  \
            py::arg("w2_scale") = std::nullopt,                                     \
            py::arg("a2_scale") = std::nullopt,                                     \
            py::arg("block_shape_n") = 0,                                           \
            py::arg("block_shape_k") = 0,                                           \
            py::arg("block_m") = 32,                                                \
            py::arg("sorted_weights") = std::nullopt);                              \
      m.def("ck_moe_per_token_quant", &ck_moe_per_token_quant,                      \
            py::arg("input"),                                                       \
            py::arg("out_quant"),                                                   \
            py::arg("out_scale"));
            
// MOE_UTILS_PYBIND: binds the MoE routing helpers topk_softmax, grouped_topk,
// biased_grouped_topk, moe_sum, moe_fused_gate, moe_align_block_size and
// sgl_moe_align_block_size on the pybind11 module object `m` (and uses the
// `py` namespace alias); both must be in scope where the macro is expanded.
// The trailing string argument of each m.def is the Python-visible docstring.
// NOTE(review): moe_fused_gate reuses biased_grouped_topk's docstring
// verbatim; confirm that text actually describes moe_fused_gate.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_UTILS_PYBIND                                                         \
      m.def("topk_softmax",                                                      \
            &aiter::topk_softmax,                                                \
            py::arg("topk_weights"),                                             \
            py::arg("topk_indices"),                                             \
            py::arg("token_expert_indices"),                                     \
            py::arg("gating_output"),                                            \
            py::arg("need_renorm"),                                              \
            "Apply topk softmax to the gating outputs.");                        \
      m.def("grouped_topk",                                                      \
            &grouped_topk,                                                       \
            py::arg("gating_output"),                                            \
            py::arg("topk_weights"),                                             \
            py::arg("topk_ids"),                                                 \
            py::arg("num_expert_group"),                                         \
            py::arg("topk_grp"),                                                 \
            py::arg("need_renorm"),                                              \
            py::arg("is_softmax")            = true,                             \
            py::arg("routed_scaling_factor") = 1.0f,                             \
            "Apply grouped topk softmax/sigmoid to the gating outputs.");        \
      m.def("biased_grouped_topk",                                               \
            &biased_grouped_topk,                                                \
            py::arg("gating_output"),                                            \
            py::arg("correction_bias"),                                          \
            py::arg("topk_weights"),                                             \
            py::arg("topk_ids"),                                                 \
            py::arg("num_expert_group"),                                         \
            py::arg("topk_grp"),                                                 \
            py::arg("need_renorm"),                                              \
            py::arg("routed_scaling_factor") = 1.0f,                             \
            "Apply biased grouped topk softmax to the gating outputs.");         \
      m.def("moe_sum", &aiter::moe_sum,                                          \
            "moe_sum(Tensor! input, Tensor output) -> ()");                      \
      m.def("moe_fused_gate",                                                    \
            &moe_fused_gate,                                                     \
            py::arg("input"),                                                    \
            py::arg("bias"),                                                     \
            py::arg("topk_weights"),                                             \
            py::arg("topk_ids"),                                                 \
            py::arg("num_expert_group"),                                         \
            py::arg("topk_group"),                                               \
            py::arg("topk"),                                                     \
            py::arg("num_fused_shared_experts"),                                 \
            py::arg("routed_scaling_factor") = 1.0,                              \
            "Apply biased grouped topk softmax to the gating outputs.");         \
      m.def("moe_align_block_size", &moe_align_block_size,                       \
            "moe_align_block_size(Tensor topk_ids, int num_experts,"             \
            "                     int block_size, Tensor! sorted_token_ids,"     \
            "                     Tensor! experts_ids,"                          \
            "                     Tensor! num_tokens_post_pad) -> ()");          \
      m.def("sgl_moe_align_block_size", &sgl_moe_align_block_size,               \
            "sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"         \
            "                         int block_size, Tensor! sorted_token_ids," \
            "                         Tensor! experts_ids,"                      \
            "                         Tensor! num_tokens_post_pad) -> ()");


// MOE_OP_PYBIND: binds the fused-MoE kernels (fmoe and its int8 / g1u1 /
// fp8-blockscale variants, plus moe_stage1_g1u1) on the pybind11 module
// object `m` (and uses the `py` namespace alias); both must be in scope at
// the expansion site. fmoe, fmoe_int8_g1u0_a16 and fmoe_g1u1_a16 are bound
// without py::arg names.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_OP_PYBIND                                                            \
      m.def("fmoe", &fmoe);                                                      \
      m.def("fmoe_int8_g1u0", &fmoe_int8_g1u0,                                   \
            py::arg("out"), py::arg("input"),                                    \
            py::arg("gate"), py::arg("down"),                                    \
            py::arg("sorted_token_ids"), py::arg("sorted_weights"),              \
            py::arg("sorted_expert_ids"), py::arg("num_valid_ids"),              \
            py::arg("topk"), py::arg("input_scale"),                             \
            py::arg("fc1_scale"), py::arg("fc2_scale"),                          \
            py::arg("fc2_smooth_scale") = std::nullopt,                          \
            py::arg("activation") = ActivationType::Silu);                       \
      m.def("fmoe_g1u1", &fmoe_g1u1,                                             \
            py::arg("out"), py::arg("input"),                                    \
            py::arg("gate"), py::arg("down"),                                    \
            py::arg("sorted_token_ids"), py::arg("sorted_weights"),              \
            py::arg("sorted_expert_ids"), py::arg("num_valid_ids"),              \
            py::arg("topk"), py::arg("input_scale"),                             \
            py::arg("fc1_scale"), py::arg("fc2_scale"),                          \
            py::arg("fc2_smooth_scale") = std::nullopt,                          \
            py::arg("activation") = ActivationType::Silu);                       \
      m.def("fmoe_g1u1_tkw1", &fmoe_g1u1_tkw1,                                   \
            py::arg("out"), py::arg("input"),                                    \
            py::arg("gate"), py::arg("down"),                                    \
            py::arg("sorted_token_ids"), py::arg("sorted_weights"),              \
            py::arg("sorted_expert_ids"), py::arg("num_valid_ids"),              \
            py::arg("topk"), py::arg("input_scale"),                             \
            py::arg("fc1_scale"), py::arg("fc2_scale"),                          \
            py::arg("fc2_smooth_scale") = std::nullopt,                          \
            py::arg("activation") = ActivationType::Silu);                       \
      m.def("fmoe_int8_g1u0_a16", &fmoe_int8_g1u0_a16);                          \
      m.def("fmoe_g1u1_a16", &fmoe_g1u1_a16);                                    \
      m.def("fmoe_fp8_blockscale_g1u1", &fmoe_fp8_blockscale_g1u1,               \
            py::arg("out"), py::arg("input"),                                    \
            py::arg("gate"), py::arg("down"),                                    \
            py::arg("sorted_token_ids"), py::arg("sorted_weights"),              \
            py::arg("sorted_expert_ids"), py::arg("num_valid_ids"),              \
            py::arg("topk"),                                                     \
            py::arg("input_scale"),                                              \
            py::arg("fc1_scale"), py::arg("fc2_scale"),                          \
            py::arg("fc_scale_blkn") = 128, py::arg("fc_scale_blkk") = 128,      \
            py::arg("fc2_smooth_scale") = std::nullopt,                          \
            py::arg("activation") = ActivationType::Silu);                       \
      m.def("moe_stage1_g1u1", &moe_stage1_g1u1,                                 \
            py::arg("input"),                                                    \
            py::arg("w1"), py::arg("w2"),                                        \
            py::arg("sorted_token_ids"),                                         \
            py::arg("sorted_expert_ids"), py::arg("num_valid_ids"),              \
            py::arg("out"),                                                      \
            py::arg("inter_dim"),                                                \
            py::arg("kernelName"),                                               \
            py::arg("block_m"),                                                  \
            py::arg("ksplit") = 0,                                               \
            py::arg("activation") = ActivationType::Silu,                        \
            py::arg("quant_type") = QuantType::No,                               \
            py::arg("a1_scale") = std::nullopt,                                  \
            py::arg("w1_scale") = std::nullopt,                                  \
            py::arg("sorted_weights") = std::nullopt);

// MOE_SUM_PYBIND: binds asm_moe_sum on the pybind11 module object `m`
// (which must be in scope at expansion); the string is its Python docstring.
// The last line intentionally has NO trailing '\': a dangling line
// continuation would splice whatever line follows the macro into its body.
#define MOE_SUM_PYBIND                                                          \
      m.def("asm_moe_sum",                                                      \
            &asm_moe_sum,                                                       \
            "asm_moe_sum(Tensor! input, Tensor output, Tensor sorted_ids) -> ()");

// MOE_SORTING_PYBIND: exposes moe_sorting_fwd to Python under the same name.
// Expands to m.def(...); a pybind11 module object `m` and the `py` namespace
// alias must be in scope where the macro is used.
//   required : topk_ids, topk_weights, sorted_token_ids, sorted_weights,
//              sorted_expert_ids, tokens_positions_per_expert,
//              num_valid_ids, moe_buf, num_experts, unit_size
//   optional : local_expert_mask (default std::nullopt)
#define MOE_SORTING_PYBIND                                  \
      m.def("moe_sorting_fwd",                              \
            &moe_sorting_fwd,                               \
            py::arg("topk_ids"),                            \
            py::arg("topk_weights"),                        \
            py::arg("sorted_token_ids"),                    \
            py::arg("sorted_weights"),                      \
            py::arg("sorted_expert_ids"),                   \
            py::arg("tokens_positions_per_expert"),         \
            py::arg("num_valid_ids"),                       \
            py::arg("moe_buf"),                             \
            py::arg("num_experts"),                         \
            py::arg("unit_size"),                           \
            py::arg("local_expert_mask") = std::nullopt);

// NORM_PYBIND: binds the layernorm2d_fwd* family on the pybind11 module
// object `m` (and uses the `py` namespace alias); both must be in scope at the
// expansion site. Every variant takes an optional trailing x_bias
// (default std::nullopt).
#define NORM_PYBIND                                                                      \
      m.def("layernorm2d_fwd", &layernorm2d,                                             \
            py::arg("input"), py::arg("weight"), py::arg("bias"),                        \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);                       \
      m.def("layernorm2d_fwd_with_add", &layernorm2d_with_add,                           \
            py::arg("out"), py::arg("input"),                                            \
            py::arg("residual_in"), py::arg("residual_out"),                             \
            py::arg("weight"), py::arg("bias"),                                          \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);                       \
      m.def("layernorm2d_fwd_with_smoothquant", &layernorm2d_with_smoothquant,           \
            py::arg("out"), py::arg("input"),                                            \
            py::arg("xscale"), py::arg("yscale"),                                        \
            py::arg("weight"), py::arg("bias"),                                          \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);                       \
      m.def("layernorm2d_fwd_with_add_smoothquant", &layernorm2d_with_add_smoothquant,   \
            py::arg("out"), py::arg("input"),                                            \
            py::arg("residual_in"), py::arg("residual_out"),                             \
            py::arg("xscale"), py::arg("yscale"),                                        \
            py::arg("weight"), py::arg("bias"),                                          \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);                       \
      m.def("layernorm2d_fwd_with_dynamicquant", &layernorm2d_with_dynamicquant,         \
            py::arg("out"), py::arg("input"),                                            \
            py::arg("yscale"), py::arg("weight"), py::arg("bias"),                       \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);                       \
      m.def("layernorm2d_fwd_with_add_dynamicquant", &layernorm2d_with_add_dynamicquant, \
            py::arg("out"), py::arg("input"),                                            \
            py::arg("residual_in"), py::arg("residual_out"),                             \
            py::arg("yscale"), py::arg("weight"), py::arg("bias"),                       \
            py::arg("epsilon"), py::arg("x_bias") = std::nullopt);
      // Disabled bindings, kept for reference; they are NOT part of NORM_PYBIND.
      // (No trailing '\' here: a '//' comment ending in a backslash splices the
      // next line into the comment and triggers -Wcomment.)
      // m.def("layernorm2d_with_add_asm", &layernorm2d_with_add_asm);
      // m.def("layernorm2d_with_add_smoothquant_asm", &layernorm2d_with_add_smoothquant_asm);

// POS_ENCODING_PYBIND: registers rotary_embedding_fwd and
// batched_rotary_embedding on the pybind11 module object `m`, which must be in
// scope where the macro is expanded. The trailing string literal in each
// m.def call is that function's Python docstring.
#define POS_ENCODING_PYBIND                                                   \
    m.def("rotary_embedding_fwd", &rotary_embedding, "rotary_embedding");     \
    m.def("batched_rotary_embedding",                                         \
          &batched_rotary_embedding,                                          \
          "batched_rotary_embedding");

// QUANT_PYBIND: registers the aiter:: quantization entry points on the
// pybind11 module object `m`. Each py::arg gives the Python keyword name;
// arguments written `py::arg(...) = value` are optional with that default,
// all others are required. static_per_tensor_quant and
// dynamic_per_tensor_quant declare no keyword names (positional only).
#define QUANT_PYBIND                                                          \
    m.def("static_per_tensor_quant", &aiter::static_per_tensor_quant);        \
    m.def("dynamic_per_tensor_quant", &aiter::dynamic_per_tensor_quant);      \
    m.def("dynamic_per_token_scaled_quant",                                   \
          &aiter::dynamic_per_token_scaled_quant,                             \
          py::arg("out"), py::arg("input"), py::arg("scales"),                \
          py::arg("scale_ub") = std::nullopt,                                 \
          py::arg("shuffle_scale") = false,                                   \
          py::arg("num_rows") = std::nullopt,                                 \
          py::arg("num_rows_factor") = 1);                                    \
    m.def("dynamic_per_group_scaled_quant_fp4",                               \
          &aiter::dynamic_per_group_scaled_quant_fp4,                         \
          py::arg("out"), py::arg("input"), py::arg("scales"),                \
          py::arg("group_size") = 32,                                         \
          py::arg("shuffle_scale") = true,                                    \
          py::arg("num_rows") = std::nullopt,                                 \
          py::arg("num_rows_factor") = 1);                                    \
    m.def("smooth_per_token_scaled_quant",                                    \
          &aiter::smooth_per_token_scaled_quant,                              \
          py::arg("out"), py::arg("input"), py::arg("scales"),                \
          py::arg("smooth_scale"),                                            \
          py::arg("smooth_scale_map") = std::nullopt,                         \
          py::arg("shuffle_scale") = false,                                   \
          py::arg("num_rows") = std::nullopt,                                 \
          py::arg("num_rows_factor") = 1);                                    \
    m.def("partial_transpose",                                                \
          &aiter::partial_transpose,                                          \
          py::arg("out"), py::arg("input"), py::arg("num_rows"));

// RMSNORM_PYBIND: registers the RMSNorm entry points on the pybind11 module
// object `m`. rms_norm_cu and fused_add_rms_norm_cu declare no py::arg names,
// so Python passes their arguments positionally; the trailing string is their
// docstring. The rmsnorm2d_fwd* bindings declare keyword names, and only
// out_before_quant (rmsnorm2d_fwd_with_add_smoothquant) has a default.
// Unlike its siblings, rmsnorm2d_fwd takes no `out` argument.
#define RMSNORM_PYBIND                                                            \
    m.def("rms_norm_cu", &rms_norm,                                               \
          "Apply Root Mean Square (RMS) Normalization to the input tensor.");     \
    m.def("fused_add_rms_norm_cu", &fused_add_rms_norm,                           \
          "In-place fused Add and RMS Normalization");                            \
    m.def("rmsnorm2d_fwd", &rmsnorm2d,                                            \
          py::arg("input"), py::arg("weight"),                                    \
          py::arg("epsilon"));                                                    \
    m.def("rmsnorm2d_fwd_with_add", &rmsnorm2d_with_add,                          \
          py::arg("out"), py::arg("input"),                                       \
          py::arg("residual_in"), py::arg("residual_out"),                        \
          py::arg("weight"), py::arg("epsilon"));                                 \
    m.def("rmsnorm2d_fwd_with_smoothquant", &rmsnorm2d_with_smoothquant,          \
          py::arg("out"), py::arg("input"),                                       \
          py::arg("xscale"), py::arg("yscale"),                                   \
          py::arg("weight"), py::arg("epsilon"));                                 \
    m.def("rmsnorm2d_fwd_with_add_smoothquant", &rmsnorm2d_with_add_smoothquant,  \
          py::arg("out"), py::arg("input"),                                       \
          py::arg("residual_in"), py::arg("residual_out"),                        \
          py::arg("xscale"), py::arg("yscale"),                                   \
          py::arg("weight"), py::arg("epsilon"),                                  \
          py::arg("out_before_quant") = std::nullopt);                            \
    m.def("rmsnorm2d_fwd_with_dynamicquant", &rmsnorm2d_with_dynamicquant,        \
          py::arg("out"), py::arg("input"),                                       \
          py::arg("yscale"), py::arg("weight"), py::arg("epsilon"));              \
    m.def("rmsnorm2d_fwd_with_add_dynamicquant", &rmsnorm2d_with_add_dynamicquant, \
          py::arg("out"), py::arg("input"),                                       \
          py::arg("residual_in"), py::arg("residual_out"),                        \
          py::arg("yscale"), py::arg("weight"), py::arg("epsilon"));
// ROPE_GENERAL_FWD_PYBIND: registers the forward rope_*_fwd_impl entry points
// on the pybind11 module object `m`. No py::arg names are declared, so every
// argument is passed positionally from Python.
#define ROPE_GENERAL_FWD_PYBIND                                        \
    m.def("rope_fwd_impl", &rope_fwd_impl);                            \
    m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl);                      \
    m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl);              \
    m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl);        \
    m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl);                    \
    m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl);

// ROPE_GENERAL_BWD_PYBIND: registers the backward rope_*_bwd_impl entry points
// on the pybind11 module object `m`. No py::arg names are declared, so every
// argument is passed positionally from Python.
#define ROPE_GENERAL_BWD_PYBIND                                        \
    m.def("rope_bwd_impl", &rope_bwd_impl);                            \
    m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl);                      \
    m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl);              \
    m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl);        \
    m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl);                    \
    m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl);

// ROPE_POS_FWD_PYBIND: registers the rope_cached_positions*_fwd_impl entry
// points on the pybind11 module object `m` (positional arguments only; no
// keyword names are declared).
#define ROPE_POS_FWD_PYBIND                                                   \
    m.def("rope_cached_positions_fwd_impl",                                   \
          &rope_cached_positions_fwd_impl);                                   \
    m.def("rope_cached_positions_2c_fwd_impl",                                \
          &rope_cached_positions_2c_fwd_impl);                                \
    m.def("rope_cached_positions_offsets_fwd_impl",                           \
          &rope_cached_positions_offsets_fwd_impl);                           \
    m.def("rope_cached_positions_offsets_2c_fwd_impl",                        \
          &rope_cached_positions_offsets_2c_fwd_impl);

// FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND: registers
// fused_qk_norm_mrope_3d_cache_pts_quant_shuffle on the pybind11 module object
// `m`. Every keyword argument is required except rotary_dim (default 0).
// NOTE(review): the Python keyword is literally "mrope_section_" (trailing
// underscore); left unchanged because renaming it would break keyword callers.
#define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND                                 \
    m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle",                   \
          &fused_qk_norm_mrope_3d_cache_pts_quant_shuffle,                    \
          py::arg("qkv"), py::arg("qw"), py::arg("kw"),                       \
          py::arg("cos_sin"), py::arg("positions"), py::arg("num_tokens"),    \
          py::arg("num_heads_q"), py::arg("num_heads_k"),                     \
          py::arg("num_heads_v"), py::arg("head_size"),                       \
          py::arg("is_neox_style"), py::arg("mrope_section_"),                \
          py::arg("is_interleaved"), py::arg("eps"),                          \
          py::arg("q_out"), py::arg("k_cache"), py::arg("v_cache"),           \
          py::arg("slot_mapping"),                                            \
          py::arg("per_tensor_k_scale"), py::arg("per_tensor_v_scale"),       \
          py::arg("k_out"), py::arg("v_out"), py::arg("return_kv"),           \
          py::arg("use_shuffle_layout"), py::arg("block_size"),               \
          py::arg("x"),                                                       \
          py::arg("rotary_dim") = 0);

// FUSED_QKNORM_ROPE_CACHE_QUANT_PYBIND: registers four aiter:: fused entry
// points on the pybind11 module object `m`.
// fused_qk_norm_rope_cache_quant_shuffle and fused_qk_norm_rope_2way declare
// no keyword names (positional only from Python). The pts and block variants
// declare keyword names; in each, only the last argument has a default
// (rotary_dim = 0 and max_tokens_per_batch = 0 respectively).
#define FUSED_QKNORM_ROPE_CACHE_QUANT_PYBIND                                  \
    m.def("fused_qk_norm_rope_cache_quant_shuffle",                           \
          &aiter::fused_qk_norm_rope_cache_quant_shuffle);                    \
    m.def("fused_qk_norm_rope_cache_pts_quant_shuffle",                       \
          &aiter::fused_qk_norm_rope_cache_pts_quant_shuffle,                 \
          py::arg("qkv"), py::arg("qw"), py::arg("kw"),                       \
          py::arg("cos_sin"), py::arg("positions"), py::arg("num_tokens"),    \
          py::arg("num_heads_q"), py::arg("num_heads_k"),                     \
          py::arg("num_heads_v"), py::arg("head_size"),                       \
          py::arg("is_neox_style"), py::arg("eps"),                           \
          py::arg("q_out"), py::arg("k_cache"), py::arg("v_cache"),           \
          py::arg("slot_mapping"),                                            \
          py::arg("per_tensor_k_scale"), py::arg("per_tensor_v_scale"),       \
          py::arg("k_out"), py::arg("v_out"), py::arg("return_kv"),           \
          py::arg("use_shuffle_layout"), py::arg("block_size"),               \
          py::arg("x"),                                                       \
          py::arg("rotary_dim") = 0);                                         \
    m.def("fused_qk_norm_rope_cache_block_quant_shuffle",                     \
          &aiter::fused_qk_norm_rope_cache_block_quant_shuffle,               \
          py::arg("qkv"),                                                     \
          py::arg("num_heads_q"), py::arg("num_heads_k"),                     \
          py::arg("num_heads_v"), py::arg("head_dim"),                        \
          py::arg("eps"), py::arg("q_weight"), py::arg("k_weight"),           \
          py::arg("cos_sin_cache"), py::arg("is_neox"),                       \
          py::arg("position_ids"),                                            \
          py::arg("k_cache"), py::arg("v_cache"), py::arg("slot_mapping"),    \
          py::arg("cu_q_len"), py::arg("kv_cache_dtype"),                     \
          py::arg("k_scale"), py::arg("v_scale"),                             \
          py::arg("max_tokens_per_batch") = 0);                               \
    m.def("fused_qk_norm_rope_2way", &aiter::fused_qk_norm_rope_2way);

// SMOOTHQUANT_PYBIND: registers smoothquant_fwd and moe_smoothquant_fwd on the
// pybind11 module object `m` (positional arguments only; no keyword names).
#define SMOOTHQUANT_PYBIND                                             \
    m.def("smoothquant_fwd", &smoothquant_fwd);                        \
    m.def("moe_smoothquant_fwd", &moe_smoothquant_fwd);

// HIPBSOLGEMM_PYBIND: registers the hipb_* GEMM helpers and
// getHipblasltKernelName on the pybind11 module object `m`. The third string
// argument of the first four m.def calls is the Python docstring. In hipb_mm
// (after solution_index) and hipb_findallsols (after mat2), every remaining
// keyword defaults to std::nullopt, i.e. None from Python.
// NOTE(review): hipb_mm exposes "scaleOut" where hipb_findallsols exposes
// "scaleC"; confirm the naming difference is intended.
#define HIPBSOLGEMM_PYBIND                                                    \
    m.def("hipb_create_extension", &hipb_create_extension,                    \
          "create_extension");                                                \
    m.def("hipb_destroy_extension", &hipb_destroy_extension,                  \
          "destroy_extension");                                               \
    m.def("hipb_mm",                                                          \
          &hipb_mm,                                                           \
          "hipb_mm",                                                          \
          py::arg("mat1"),                                                    \
          py::arg("mat2"),                                                    \
          py::arg("solution_index"),                                          \
          py::arg("bias")      = std::nullopt,                                \
          py::arg("out_dtype") = std::nullopt,                                \
          py::arg("scaleA")    = std::nullopt,                                \
          py::arg("scaleB")    = std::nullopt,                                \
          py::arg("scaleOut")  = std::nullopt,                                \
          py::arg("scaleType") = std::nullopt);                               \
    m.def("hipb_findallsols",                                                 \
          &hipb_findallsols,                                                  \
          "hipb_findallsols",                                                 \
          py::arg("mat1"),                                                    \
          py::arg("mat2"),                                                    \
          py::arg("bias")      = std::nullopt,                                \
          py::arg("out_dtype") = std::nullopt,                                \
          py::arg("scaleA")    = std::nullopt,                                \
          py::arg("scaleB")    = std::nullopt,                                \
          py::arg("scaleC")    = std::nullopt,                                \
          py::arg("scaleType") = std::nullopt);                               \
    m.def("getHipblasltKernelName", &getHipblasltKernelName);

// ROCSOLGEMM_PYBIND: registers the rocb_* helpers on the pybind11 module
// object `m`. The Python names rocb_mm and rocb_findallsols map to the C++
// symbols RocSolIdxBlas and RocFindAllSolIdxBlas; the third argument of each
// m.def call is its docstring.
#define ROCSOLGEMM_PYBIND                                                     \
    m.def("rocb_create_extension", &rocb_create_extension,                    \
          "create_extension");                                                \
    m.def("rocb_destroy_extension", &rocb_destroy_extension,                  \
          "destroy_extension");                                               \
    m.def("rocb_mm", &RocSolIdxBlas, "mm");                                   \
    m.def("rocb_findallsols", &RocFindAllSolIdxBlas,                          \
          "rocblas_find_all_sols");

// AITER_ENUM_PYBIND: exposes the C++ enums QuantType and ActivationType as
// Python classes of the same name on module `m`. export_values() additionally
// copies each member into the module namespace (e.g. `per_Tensor`, `Silu`),
// and implicitly_convertible<int, ...>() lets bound functions that take these
// enums accept plain Python ints.
// NOTE(review): both enums define a member named "No", so the second
// export_values() presumably rebinds module-level `No` to ActivationType.No;
// callers should use the qualified QuantType.No / ActivationType.No.
#define AITER_ENUM_PYBIND                                             \
    pybind11::enum_<QuantType>(m, "QuantType")                        \
        .value("No",          QuantType::No)                          \
        .value("per_Tensor",  QuantType::per_Tensor)                  \
        .value("per_Token",   QuantType::per_Token)                   \
        .value("per_1x32",    QuantType::per_1x32)                    \
        .value("per_1x128",   QuantType::per_1x128)                   \
        .value("per_128x128", QuantType::per_128x128)                 \
        .export_values();                                             \
    pybind11::enum_<ActivationType>(m, "ActivationType")              \
        .value("No",   ActivationType::No)                            \
        .value("Silu", ActivationType::Silu)                          \
        .value("Gelu", ActivationType::Gelu)                          \
        .export_values();                                             \
    pybind11::implicitly_convertible<int, QuantType>();               \
    pybind11::implicitly_convertible<int, ActivationType>();

// TOPK_PLAIN_PYBIND: registers topk_plain on the pybind11 module object `m`.
// values, topk_ids, topk_out and topk are required; largest, rowStarts,
// rowEnds, stride0 and stride1 are optional with the defaults shown.
// NOTE(review): rowStarts/rowEnds default to an undefined torch::Tensor(),
// which PyTorch's pybind11 caster presumably presents to Python as None;
// confirm the bound C++ parameters accept None (e.g. std::optional<Tensor>),
// otherwise calls that omit them may fail argument conversion.
#define TOPK_PLAIN_PYBIND                                             \
    m.def("topk_plain", &topk_plain,                                  \
          py::arg("values"), py::arg("topk_ids"),                     \
          py::arg("topk_out"), py::arg("topk"),                       \
          py::arg("largest") = true,                                  \
          py::arg("rowStarts") = torch::Tensor(),                     \
          py::arg("rowEnds") = torch::Tensor(),                       \
          py::arg("stride0") = -1,                                    \
          py::arg("stride1") = 1);

// TOPK_TRANSFORM_PYBIND: registers the fast_topk_* interfaces on the pybind11
// module object `m`. In every binding the only optional keyword is
// row_starts_opt, which defaults to std::nullopt (None from Python).
#define TOPK_TRANSFORM_PYBIND                                                 \
    m.def("fast_topk_interface", &fast_topk_interface,                        \
          py::arg("score"), py::arg("indices"), py::arg("lengths"),           \
          py::arg("row_starts_opt") = std::nullopt);                          \
    m.def("fast_topk_transform_interface", &fast_topk_transform_interface,    \
          py::arg("score"), py::arg("lengths"),                               \
          py::arg("dst_page_table"), py::arg("src_page_table"),               \
          py::arg("cu_seqlens_q"),                                            \
          py::arg("row_starts_opt") = std::nullopt);                          \
    m.def("fast_topk_transform_ragged_interface",                             \
          &fast_topk_transform_ragged_interface,                              \
          py::arg("score"), py::arg("lengths"),                               \
          py::arg("topk_indices_ragged"), py::arg("topk_indices_offset"),     \
          py::arg("row_starts_opt") = std::nullopt);

// MOE_C_PYBIND: registers the moe_c_* bindings (marlin GEMMs, w8a16 / wna16 / w8a8
// block-wise GEMMs, moe_align_block_size, moe_sum, sgl_moe_align_block_size,
// topk_softmax and silu_and_mul) on the module object `m`, which must be in scope
// wherever this macro is expanded.
//
// The Python name of each binding and every keyword name passed to py::arg are part
// of the Python-facing API and must stay stable. That includes the "experts_ids"
// spelling used by the two align_block_size bindings, which differs from the
// "expert_ids" keyword used by the GEMM bindings.
//
// Maintenance notes for this multi-line macro:
//  * Every line except the last must end with a backslash IMMEDIATELY followed by
//    the newline. Before C++23 a backslash followed by trailing whitespace is not a
//    guaranteed line splice (GCC/Clang accept it only with a warning); a compiler
//    that rejects it ends the macro early and leaves the remaining m.def calls as
//    invalid file-scope code.
//  * Use only /* */ comments inside the body; a // comment would swallow the
//    continuation of its line.
#define MOE_C_PYBIND                                          \
    /* ---- moe_gemm_marlin_w8a8 ---- */                      \
    m.def("moe_c_moe_gemm_marlin_w8a8",                       \
          &moe_c_moe_gemm_marlin_w8a8,                        \
          py::arg("input"),                                   \
          py::arg("b_qweight"),                               \
          py::arg("output"),                                  \
          py::arg("a_scale"),                                 \
          py::arg("b_scale"),                                 \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("mode"),                                    \
          py::arg("delta"));                                  \
    /* ---- moe_gemm_marlin_w4a8 ---- */                      \
    m.def("moe_c_moe_gemm_marlin_w4a8",                       \
          &moe_c_moe_gemm_marlin_w4a8,                        \
          py::arg("input"),                                   \
          py::arg("b_qweight"),                               \
          py::arg("output"),                                  \
          py::arg("a_scale"),                                 \
          py::arg("b_scale"),                                 \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("mode"),                                    \
          py::arg("delta"));                                  \
    /* ---- moe_gemm_marlin_w8a8_fp8 ---- */                  \
    m.def("moe_c_moe_gemm_marlin_w8a8_fp8",                   \
          &moe_c_moe_gemm_marlin_w8a8_fp8,                    \
          py::arg("input"),                                   \
          py::arg("b_qweight"),                               \
          py::arg("output"),                                  \
          py::arg("a_scale"),                                 \
          py::arg("b_scale"),                                 \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("mode"),                                    \
          py::arg("delta"));                                  \
    /* ---- moe_gemm_marlin_w4a16 ---- */                     \
    m.def("moe_c_moe_gemm_marlin_w4a16",                      \
          &moe_c_moe_gemm_marlin_w4a16,                       \
          py::arg("input"),                                   \
          py::arg("b_qweight"),                               \
          py::arg("output"),                                  \
          py::arg("b_scale"),                                 \
          py::arg("b_zeros"),                                 \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("mode"),                                    \
          py::arg("delta"));                                  \
    /* ---- moe_w8a16_gemm_block_wise ---- */                 \
    m.def("moe_c_moe_w8a16_gemm_block_wise",                  \
          &moe_c_moe_w8a16_gemm_block_wise,                   \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("group_size_n"),                            \
          py::arg("group_size_k"),                            \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("bit"));                                    \
    /* ---- moe_w8a16_gemm_awq ---- */                        \
    m.def("moe_c_moe_w8a16_gemm_awq",                         \
          &moe_c_moe_w8a16_gemm_awq,                          \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("bit"));                                    \
    /* ---- moe_wna16_gemm ---- */                            \
    m.def("moe_c_moe_wna16_gemm",                             \
          &moe_c_moe_wna16_gemm,                              \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- moe_wna16_gemm_2 ---- */                          \
    m.def("moe_c_moe_wna16_gemm_2",                           \
          &moe_c_moe_wna16_gemm_2,                            \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- moe_align_block_size ---- */                      \
    m.def("moe_c_moe_align_block_size",                       \
          &moe_c_moe_align_block_size,                        \
          py::arg("topk_ids"),                                \
          py::arg("num_experts"),                             \
          py::arg("block_size"),                              \
          py::arg("sorted_token_ids"),                        \
          py::arg("experts_ids"),                             \
          py::arg("num_tokens_post_pad"));                    \
    /* ---- moe_sum ---- */                                   \
    m.def("moe_c_moe_sum",                                    \
          &moe_c_moe_sum,                                     \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("topk_ids"));                               \
    /* ---- moe_wna16_gemm_base ---- */                       \
    m.def("moe_c_moe_wna16_gemm_base",                        \
          &moe_c_moe_wna16_gemm_base,                         \
          py::arg("input"),                                   \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("bit"));                                    \
    /* ---- sgl_moe_align_block_size ---- */                  \
    m.def("moe_c_sgl_moe_align_block_size",                   \
          &moe_c_sgl_moe_align_block_size,                    \
          py::arg("topk_ids"),                                \
          py::arg("num_experts"),                             \
          py::arg("block_size"),                              \
          py::arg("sorted_token_ids"),                        \
          py::arg("experts_ids"),                             \
          py::arg("num_tokens_post_pad"));                    \
    /* ---- moe_w8a8_gemm_block_wise ---- */                  \
    m.def("moe_c_moe_w8a8_gemm_block_wise",                   \
          &moe_c_moe_w8a8_gemm_block_wise,                    \
          py::arg("input"),                                   \
          py::arg("a_scales"),                                \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("group_size_n"),                            \
          py::arg("group_size_k"),                            \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- moe_w8a8_gemm_block_wise_kernel2 ---- */          \
    m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2",           \
          &moe_c_moe_w8a8_gemm_block_wise_kernel2,            \
          py::arg("input"),                                   \
          py::arg("a_scales"),                                \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("group_size_n"),                            \
          py::arg("group_size_k"),                            \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- moe_w8a8_gemm_block_wise_fp8 ---- */              \
    m.def("moe_c_moe_w8a8_gemm_block_wise_fp8",               \
          &moe_c_moe_w8a8_gemm_block_wise_fp8,                \
          py::arg("input"),                                   \
          py::arg("a_scales"),                                \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("group_size_n"),                            \
          py::arg("group_size_k"),                            \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- moe_w8a8_gemm_block_wise_kernel2_fp8 ---- */      \
    m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8",       \
          &moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8,        \
          py::arg("input"),                                   \
          py::arg("a_scales"),                                \
          py::arg("output"),                                  \
          py::arg("b_qweight"),                               \
          py::arg("b_scales"),                                \
          py::arg("b_qzeros"),                                \
          py::arg("topk_weights"),                            \
          py::arg("sorted_token_ids"),                        \
          py::arg("expert_ids"),                              \
          py::arg("num_tokens_post_pad"),                     \
          py::arg("group_size_n"),                            \
          py::arg("group_size_k"),                            \
          py::arg("top_k"),                                   \
          py::arg("BLOCK_SIZE_M"),                            \
          py::arg("BLOCK_SIZE_N"),                            \
          py::arg("BLOCK_SIZE_K"),                            \
          py::arg("kloops"),                                  \
          py::arg("nloops"),                                  \
          py::arg("bit"));                                    \
    /* ---- topk_softmax ---- */                              \
    m.def("moe_c_topk_softmax",                               \
          &moe_c_topk_softmax,                                \
          py::arg("topk_weights"),                            \
          py::arg("topk_indices"),                            \
          py::arg("token_expert_indices"),                    \
          py::arg("gating_output"));                          \
    /* ---- silu_and_mul ---- */                              \
    m.def("moe_c_silu_and_mul",                               \
          &moe_c_silu_and_mul,                                \
          py::arg("out"),                                     \
          py::arg("input"));