// SPDX-License-Identifier: MIT // ---- aiter_tensor_t / stream pybind infrastructure (ported from aiter-github) ---- #include "aiter_tensor.h" #include "aiter_stream.h" #include #include namespace py = pybind11; #ifndef AITER_SET_STREAM_PYBIND #define AITER_SET_STREAM_PYBIND \ m.def("_set_current_hip_stream", \ [](int64_t stream_ptr) { \ aiter::setCurrentHIPStream((hipStream_t)stream_ptr); \ }, \ pybind11::arg("stream_ptr")); #endif #ifndef AITER_CORE_PYBIND #define AITER_CORE_PYBIND \ pybind11::enum_(m, "QuantType") \ .value("No", QuantType::No) \ .value("per_Tensor", QuantType::per_Tensor) \ .value("per_Token", QuantType::per_Token) \ .value("per_1x32", QuantType::per_1x32) \ .value("per_1x128", QuantType::per_1x128) \ .value("per_128x128", QuantType::per_128x128) \ .value("per_256x128", QuantType::per_256x128) \ .value("per_1024x128", QuantType::per_1024x128) \ .export_values(); \ pybind11::enum_(m, "ActivationType") \ .value("No", ActivationType::No) \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ .value("Swiglu", ActivationType::Swiglu) \ .export_values(); \ pybind11::implicitly_convertible(); \ pybind11::implicitly_convertible(); \ AITER_SET_STREAM_PYBIND \ pybind11::class_(m, "aiter_tensor_t") \ .def(pybind11::init<>()) \ .def(pybind11::init([](int64_t data_ptr, size_t numel, int ndim, \ const std::vector& shape, \ const std::vector& strides, \ int dtype, int device_id) { \ aiter_tensor_t at{}; \ at.ptr = (void*)data_ptr; \ at.numel_ = numel; \ at.ndim = ndim; \ for(int i = 0; i < ndim && i < 8; i++) { \ at.shape[i] = shape[i]; \ at.strides[i] = strides[i]; \ } \ at.dtype_ = (AiterDtype)dtype; \ at.device_id = device_id; \ return at; \ }), \ pybind11::arg("data_ptr"), \ pybind11::arg("numel"), \ pybind11::arg("ndim"), \ pybind11::arg("shape"), \ pybind11::arg("strides"), \ pybind11::arg("dtype"), \ pybind11::arg("device_id")) \ .def_readwrite("numel_", &aiter_tensor_t::numel_) \ .def_readwrite("ndim", &aiter_tensor_t::ndim) \ .def_readwrite("device_id", &aiter_tensor_t::device_id); #endif // Registers only aiter_tensor_t and _set_current_hip_stream — no enum registrations. // Use this in modules that already have QuantType/ActivationType (e.g. via module_aiter_enum). #ifndef AITER_TENSOR_PYBIND #define AITER_TENSOR_PYBIND \ AITER_SET_STREAM_PYBIND \ pybind11::class_(m, "aiter_tensor_t") \ .def(pybind11::init<>()) \ .def(pybind11::init([](int64_t data_ptr, size_t numel, int ndim, \ const std::vector& shape, \ const std::vector& strides, \ int dtype, int device_id) { \ aiter_tensor_t at{}; \ at.ptr = (void*)data_ptr; \ at.numel_ = numel; \ at.ndim = ndim; \ for(int i = 0; i < ndim && i < 8; i++) { \ at.shape[i] = shape[i]; \ at.strides[i] = strides[i]; \ } \ at.dtype_ = (AiterDtype)dtype; \ at.device_id = device_id; \ return at; \ }), \ pybind11::arg("data_ptr"), \ pybind11::arg("numel"), \ pybind11::arg("ndim"), \ pybind11::arg("shape"), \ pybind11::arg("strides"), \ pybind11::arg("dtype"), \ pybind11::arg("device_id")) \ .def_readwrite("numel_", &aiter_tensor_t::numel_) \ .def_readwrite("ndim", &aiter_tensor_t::ndim) \ .def_readwrite("device_id", &aiter_tensor_t::device_id); #endif // ---- end aiter_tensor_t / stream pybind infrastructure ---- #define ACTIVATION_PYBIND \ m.def("silu_and_mul", &aiter::silu_and_mul, "Activation function used in SwiGLU.", \ py::arg("out"), py::arg("input")); \ m.def("scaled_silu_and_mul", &aiter::scaled_silu_and_mul, "Activation function used in scaled SwiGLU.",\ py::arg("out"), py::arg("input"), py::arg("scale")); \ m.def("gelu_and_mul", &aiter::gelu_and_mul, "Activation function used in GELU.", \ py::arg("out"), py::arg("input")); \ m.def("gelu_tanh_and_mul", &aiter::gelu_tanh_and_mul, "Activation function used in GELU tanh.", \ py::arg("out"), py::arg("input")); #define AITER_OPERATOR_PYBIND \ m.def("add", &aiter_add, "apply for add with transpose and broadcast."); \ m.def("mul", &aiter_mul, "apply for mul with transpose and broadcast."); \ m.def("sub", &aiter_sub, "apply for sub with transpose and broadcast."); \ m.def("div", &aiter_div, "apply for div with transpose and broadcast."); \ m.def("add_", &aiter_add_, "apply for add_ with transpose and broadcast."); \ m.def("mul_", &aiter_mul_, "apply for mul_ with transpose and broadcast."); \ m.def("sub_", &aiter_sub_, "apply for sub_ with transpose and broadcast."); \ m.def("div_", &aiter_div_, "apply for div_ with transpose and broadcast."); #define AITER_UNARY_PYBIND \ m.def("sigmoid", &aiter_sigmoid, "apply for sigmoid."); \ m.def("tanh", &aiter_tanh, "apply for tanh."); #define ATTENTION_ASM_MLA_PYBIND \ m.def("mla_decode_stage1_asm_fwd", &mla_decode_stage1_asm_fwd, "mla_decode_stage1_asm_fwd", \ py::arg("Q"), \ py::arg("KV"), \ py::arg("qo_indptr"), \ py::arg("kv_indptr"), \ py::arg("kv_page_indices"), \ py::arg("kv_last_page_lens"), \ py::arg("max_seqlen_q"), \ py::arg("softmax_scale"), \ py::arg("splitData"), \ py::arg("splitLse")); \ m.def("mla_prefill_asm_fwd", &mla_prefill_asm_fwd, "mla_prefill_asm_fwd", \ py::arg("Q"), \ py::arg("KV"), \ py::arg("qo_indptr"), \ py::arg("kv_indptr"), \ py::arg("kv_page_indices"), \ py::arg("kv_last_page_lens"), \ py::arg("max_seqlen_q"), \ py::arg("softmax_scale"), \ py::arg("splitData"), \ py::arg("splitLse")); #define ATTENTION_ASM_PYBIND \ m.def("pa_fwd_asm", &pa_fwd, "pa_fwd", \ py::arg("Q"), \ py::arg("K"), \ py::arg("V"), \ py::arg("block_tables"), \ py::arg("context_lens"), \ py::arg("max_num_blocks"), \ py::arg("K_QScale") = std::nullopt, \ py::arg("V_QScale") = std::nullopt, \ py::arg("out_") = std::nullopt, \ py::arg("high_precision") = 1); #define ATTENTION_CK_PYBIND \ m.def("pa_fwd_naive", &pa_fwd_naive, "pa_fwd_naive", \ py::arg("Q"), \ py::arg("K"), \ py::arg("V"), \ py::arg("block_tables"), \ py::arg("context_lens"), \ py::arg("k_dequant_scales"), \ py::arg("v_dequant_scales"), \ py::arg("max_seq_len"), \ py::arg("num_kv_heads"), \ py::arg("scale_s"), \ py::arg("scale_k"), \ py::arg("scale_v"), \ py::arg("block_size"), \ py::arg("quant_algo"), \ py::arg("out_") = std::nullopt); #define ATTENTION_PYBIND \ m.def("paged_attention_rocm", &paged_attention, \ "paged_attention_rocm(Tensor! out, Tensor exp_sums," \ " Tensor max_logits, Tensor tmp_out," \ " Tensor query, Tensor key_cache," \ " Tensor value_cache, int num_kv_heads," \ " float scale, Tensor block_tables," \ " Tensor context_lens, int block_size," \ " int max_context_len," \ " Tensor? alibi_slopes," \ " str kv_cache_dtype," \ " float k_scale, float v_scale) -> ()"); #define ATTENTION_RAGGED_PYBIND \ m.def("paged_attention_ragged", &paged_attention_ragged, \ "paged_attention_ragged(Tensor! out, Tensor exp_sums," \ " Tensor max_logits, Tensor tmp_out," \ " Tensor query, Tensor key_cache," \ " Tensor value_cache, int num_kv_heads," \ " float scale, Tensor block_tables," \ " Tensor context_lens, int block_size," \ " int max_context_len," \ " Tensor? alibi_slopes," \ " str kv_cache_dtype," \ " float k_scale, float v_scale) -> ()"); #define BATCHED_GEMM_A8W8_PYBIND \ m.def("batched_gemm_a8w8", &batched_gemm_a8w8, "batched_gemm_a8w8", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), \ py::arg("bias") = std::nullopt, py::arg("splitK") = 0); #define BATCHED_GEMM_A8W8_TUNE_PYBIND \ m.def("batched_gemm_a8w8_tune", &batched_gemm_a8w8_tune, "batched_gemm_a8w8_tune", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), py::arg("kernelId") = 0, \ py::arg("splitK") = 0); #define CACHE_PYBIND \ m.def("swap_blocks", &swap_blocks, \ "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); \ m.def("copy_blocks", ©_blocks, \ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " \ "Tensor block_mapping) -> ()"); \ \ m.def("reshape_and_cache", &reshape_and_cache, \ "reshape_and_cache"); \ m.def("reshape_and_cache_flash", &reshape_and_cache_flash, \ "reshape_and_cache_flash(Tensor key, Tensor value," \ " Tensor! key_cache," \ " Tensor! value_cache," \ " Tensor slot_mapping," \ " str kv_cache_dtype," \ " float k_scale, float v_scale) -> ()"); \ m.def("reshape_and_cache_with_pertoken_quant", &reshape_and_cache_with_pertoken_quant, \ "reshape_and_cache_with_pertoken_quant(Tensor key, Tensor value," \ " Tensor! key_cache," \ " Tensor! value_cache," \ " Tensor! k_dequant_scales," \ " Tensor! v_dequant_scales," \ " Tensor slot_mapping) -> ()"); \ m.def("reshape_and_cache_with_block_quant", &reshape_and_cache_with_block_quant, \ "reshape_and_cache_with_block_quant(Tensor key, Tensor value," \ " Tensor! key_cache," \ " Tensor! value_cache," \ " Tensor! k_dequant_scales," \ " Tensor! v_dequant_scales," \ " Tensor slot_mapping," \ " const bool asm_layout) -> ()"); \ m.def("convert_fp8", &convert_fp8, \ "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " \ "str kv_cache_dtype) -> ()"); #define CUSTOM_ALL_REDUCE_PYBIND \ AITER_TENSOR_PYBIND \ m.def("init_custom_ar", \ &aiter::init_custom_ar, \ py::arg("meta_ptr"), \ py::arg("rank_data_ptr"), \ py::arg("rank_data_sz"), \ py::arg("ipc_handle_ptrs"), \ py::arg("offsets"), \ py::arg("rank"), \ py::arg("fully_connected")); \ m.def("all_reduce", \ &aiter::all_reduce, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("out"), \ py::arg("use_new"), \ py::arg("open_fp8_quant"), \ py::arg("reg_inp_ptr"), \ py::arg("reg_inp_bytes")); \ m.def("reduce_scatter", \ &aiter::reduce_scatter, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("out"), \ py::arg("reg_ptr"), \ py::arg("reg_bytes")); \ m.def("all_gather_reg", \ &aiter::all_gather_reg, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("out"), \ py::arg("dim")); \ m.def("all_gather_unreg", \ &aiter::all_gather_unreg, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("reg_buffer"), \ py::arg("out"), \ py::arg("reg_bytes"), \ py::arg("dim")); \ m.def("fused_allreduce_rmsnorm", \ &aiter::fused_allreduce_rmsnorm, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("res_inp"), \ py::arg("res_out"), \ py::arg("out"), \ py::arg("w"), \ py::arg("eps"), \ py::arg("reg_ptr"), \ py::arg("reg_bytes"), \ py::arg("use_1stage")); \ m.def("fused_allreduce_rmsnorm_quant", \ &aiter::fused_allreduce_rmsnorm_quant, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("res_inp"), \ py::arg("res_out"), \ py::arg("out"), \ py::arg("scale_out"), \ py::arg("w"), \ py::arg("eps"), \ py::arg("reg_ptr"), \ py::arg("reg_bytes"), \ py::arg("use_1stage")); \ m.def("fused_allreduce_rmsnorm_quant_per_group", \ &aiter::fused_allreduce_rmsnorm_quant_per_group, \ py::arg("_fa"), \ py::arg("inp"), \ py::arg("res_inp"), \ py::arg("res_out"), \ py::arg("out"), \ py::arg("scale_out"), \ py::arg("w"), \ py::arg("eps"), \ py::arg("group_size"), \ py::arg("reg_ptr"), \ py::arg("reg_bytes"), \ py::arg("use_1stage"), \ py::arg("bf16_out_ptr") = static_cast(0)); \ m.def("fused_qknorm_allreduce", \ &aiter::fused_qknorm_allreduce, \ py::arg("_fa"), \ py::arg("qkv_in"), \ py::arg("q_w"), \ py::arg("k_w"), \ py::arg("q_out"), \ py::arg("k_out"), \ py::arg("v_out"), \ py::arg("eps"), \ py::arg("reg_ptr"), \ py::arg("reg_bytes")); \ m.def("dispose", &aiter::dispose, py::arg("_fa")); \ m.def("meta_size", &aiter::meta_size); \ m.def("register_input_buffer", \ &aiter::register_input_buffer, \ py::arg("_fa"), \ py::arg("self_ptr"), \ py::arg("ipc_handle_ptrs"), \ py::arg("offsets")); \ m.def("register_output_buffer", \ &aiter::register_output_buffer, \ py::arg("_fa"), \ py::arg("self_ptr"), \ py::arg("ipc_handle_ptrs"), \ py::arg("offsets")); \ m.def("get_graph_buffer_count", &aiter::get_graph_buffer_count, py::arg("_fa")); \ m.def("get_graph_buffer_ipc_meta", \ &aiter::get_graph_buffer_ipc_meta, \ py::arg("_fa"), \ py::arg("handle_out"), \ py::arg("offset_out")); \ m.def("register_graph_buffers", \ &aiter::register_graph_buffers, \ py::arg("_fa"), \ py::arg("handle_ptrs"), \ py::arg("offset_ptrs")); \ m.def("allocate_meta_buffer", &aiter::allocate_meta_buffer, py::arg("size")); \ m.def("free_meta_buffer", &aiter::free_meta_buffer, py::arg("ptr")); \ m.def("get_meta_buffer_ipc_handle", \ &aiter::get_meta_buffer_ipc_handle, \ py::arg("inp_ptr"), \ py::arg("out_handle_ptr")); #define CUSTOM_PYBIND \ m.def("wvSpltK", &wvSpltK, "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," \ " int CuCount) -> ()"); \ m.def("LLMM1", &LLMM1, "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " \ "()"); #define GEMM_A8W8_ASM_PYBIND \ m.def("gemm_a8w8_asm", &gemm_a8w8_asm, \ "Asm gemm a8w8 , weight should be shuffle to layout(32,16)", \ py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), \ py::arg("Out"), py::arg("bias"), \ py::arg("sub_m") = 128, py::arg("sub_n") = 128, \ py::arg("pad_a") = 0, py::arg("pad_b") = 0, \ py::arg("pad_c") = 0, py::arg("splitK") = 0); #define GEMM_A8W8_BLOCKSCALE_PYBIND \ m.def("gemm_a8w8_blockscale", &gemm_a8w8_blockscale, "fp8 blockscale gemm", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out")); #define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND \ m.def("gemm_a8w8_blockscale_tune", &gemm_a8w8_blockscale_tune, "gemm_a8w8_blockscale_tune", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), py::arg("kernelId") = 0, \ py::arg("splitK") = 0); #define GEMM_A8W8_PYBIND \ m.def("gemm_a8w8", &gemm_a8w8, "gemm_a8w8", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), \ py::arg("bias") = std::nullopt, py::arg("splitK") = 0); #define GEMM_A8W8_TUNE_PYBIND \ m.def("gemm_a8w8_tune", &gemm_a8w8_tune, "gemm_a8w8_tune", py::arg("XQ"), py::arg("WQ"), \ py::arg("x_scale"), py::arg("w_scale"), py::arg("Out"), py::arg("kernelId") = 0, \ py::arg("splitK") = 0); #define MHA_BWD_ASM_PYBIND \ m.def("fmha_v3_bwd", &aiter::torch_itfs::fmha_v3_bwd, \ py::arg("dout"), \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("out"), \ py::arg("softmax_lse"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("deterministic"), \ py::arg("is_v3_atomic_fp32"), \ py::arg("how_v3_bf16_cvt"), \ py::arg("dq") = std::nullopt, \ py::arg("dk") = std::nullopt, \ py::arg("dv") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_VARLEN_BWD_ASM_PYBIND \ m.def("fmha_v3_varlen_bwd", &aiter::torch_itfs::fmha_v3_varlen_bwd, \ py::arg("dout"), \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("out"), \ py::arg("softmax_lse"), \ py::arg("cu_seqlens_q"), \ py::arg("cu_seqlens_k"), \ py::arg("max_seqlen_q"), \ py::arg("max_seqlen_k"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("zero_tensors"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("deterministic"), \ py::arg("is_v3_atomic_fp32"), \ py::arg("how_v3_bf16_cvt"), \ py::arg("dq") = std::nullopt, \ py::arg("dk") = std::nullopt, \ py::arg("dv") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_BWD_PYBIND \ m.def("mha_bwd", &aiter::torch_itfs::mha_bwd, \ py::arg("dout"), \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("out"), \ py::arg("softmax_lse"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("deterministic"), \ py::arg("dq") = std::nullopt, \ py::arg("dk") = std::nullopt, \ py::arg("dv") = std::nullopt, \ py::arg("dbias") = std::nullopt, \ py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_FWD_ASM_PYBIND \ m.def("fmha_v3_fwd", &aiter::torch_itfs::fmha_v3_fwd, \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("return_softmax_lse"), \ py::arg("return_dropout_randval"), \ py::arg("out") = std::nullopt, \ py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_FWD_PYBIND \ m.def("mha_fwd", &aiter::torch_itfs::mha_fwd, \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("return_softmax_lse"), \ py::arg("return_dropout_randval"), \ py::arg("out") = std::nullopt, \ py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_VARLEN_BWD_PYBIND \ m.def("mha_varlen_bwd", &aiter::torch_itfs::mha_varlen_bwd, \ py::arg("dout"), \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("out"), \ py::arg("softmax_lse"), \ py::arg("cu_seqlens_q"), \ py::arg("cu_seqlens_k"), \ py::arg("max_seqlen_q"), \ py::arg("max_seqlen_k"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("zero_tensors"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("deterministic"), \ py::arg("dq") = std::nullopt, \ py::arg("dk") = std::nullopt, \ py::arg("dv") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("rng_state") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_VARLEN_FWD_PYBIND \ m.def("mha_varlen_fwd", &aiter::torch_itfs::mha_varlen_fwd, \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("cu_seqlens_q"), \ py::arg("cu_seqlens_k"), \ py::arg("max_seqlen_q"), \ py::arg("max_seqlen_k"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("logits_soft_cap"), \ py::arg("zero_tensors"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("return_softmax_lse"), \ py::arg("return_dropout_randval"), \ py::arg("out") = std::nullopt, \ py::arg("block_table") = std::nullopt, \ py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MHA_BATCH_PREFILL_PYBIND \ m.def("mha_batch_prefill", &aiter::torch_itfs::mha_batch_prefill, \ py::arg("q"), py::arg("k"), py::arg("v"), \ py::arg("cu_seqlens_q"), \ py::arg("kv_indptr"), \ py::arg("kv_page_indices"), \ py::arg("max_seqlen_q"), \ py::arg("max_seqlen_k"), \ py::arg("dropout_p"), \ py::arg("softmax_scale"), \ py::arg("logits_soft_cap"), \ py::arg("zero_tensors"), \ py::arg("is_causal"), \ py::arg("window_size_left"), \ py::arg("window_size_right"), \ py::arg("return_softmax_lse"), \ py::arg("return_dropout_randval"), \ py::arg("out") = std::nullopt, \ py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ m.def("ck_moe_stage1", &ck_moe_stage1, \ py::arg("hidden_states"), \ py::arg("w1"), \ py::arg("w2"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_expert_ids"), \ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ py::arg("w1_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, \ py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt, \ py::arg("act_op") = 0); \ \ m.def("ck_moe_stage2", &ck_moe_stage2, \ py::arg("inter_states"), \ py::arg("w1"), \ py::arg("w2"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_expert_ids"), \ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ py::arg("w2_scale") = std::nullopt, \ py::arg("a2_scale") = std::nullopt, \ py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt); \ #define MOE_ASM_2STAGES_PYBIND \ m.def("asm_fmoe_stage1", &asm_fmoe_stage1, \ py::arg("out"), \ py::arg("input"), \ py::arg("gate"), \ py::arg("down"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), \ py::arg("num_valid_ids"), \ py::arg("top_k"), \ py::arg("scale_a") = std::nullopt, \ py::arg("scale_b") = std::nullopt, \ py::arg("zero_points") = std::nullopt, \ py::arg("mode") = 0, \ py::arg("solidx") = 0, \ py::arg("block_size") = 16, \ py::arg("persist_groups") = 0); \ \ m.def("asm_fmoe_stage2", &asm_fmoe_stage2, \ py::arg("out"), \ py::arg("input"), \ py::arg("gate"), \ py::arg("down"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), \ py::arg("num_valid_ids"), \ py::arg("top_k"), \ py::arg("scale_a") = std::nullopt, \ py::arg("scale_b") = std::nullopt, \ py::arg("zero_points") = std::nullopt, \ py::arg("mode") = 0, \ py::arg("solidx") = 0, \ py::arg("block_size") = 16, \ py::arg("persist_groups") = 0); \ \ m.def("asm_fmoe_a8", &asm_fmoe_a8, \ py::arg("out"), \ py::arg("input"), \ py::arg("gate"), \ py::arg("down"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), \ py::arg("num_valid_ids"), \ py::arg("top_k"), \ py::arg("scale_a") = std::nullopt, \ py::arg("scale_b") = std::nullopt, \ py::arg("zero_points") = std::nullopt, \ py::arg("mode") = 0, \ py::arg("solidx") = 0, \ py::arg("out_type") = 0, \ py::arg("persist_groups") = 0, \ py::arg("use_shuffle") = 0); \ \ m.def("asm_moe_get_solutions", &asm_moe_get_solutions, \ py::arg("hidden_states"), \ py::arg("w1"), \ py::arg("w2"), \ py::arg("topk_weights"), \ py::arg("topk_ids"), \ py::arg("use_int8_w8a16") = false, \ py::arg("use_int4_w4a16") = false, \ py::arg("use_int8_w8a8") = false, \ py::arg("use_int4_w4a8") = false, \ py::arg("use_fp8_w8a8") = false, \ py::arg("per_channel_quant") = false, \ py::arg("w1_zp") = std::nullopt, \ py::arg("w2_zp") = std::nullopt, \ py::arg("w1_scale") = std::nullopt, \ py::arg("w2_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, \ py::arg("a2_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("expert_mask") = std::nullopt); \ #define AWQ_GEMM_ASM_PYBIND \ m.def("awq_gemm_asm", &awq_gemm_asm, \ py::arg("out"), \ py::arg("mat1"), \ py::arg("mat2"), \ py::arg("zero") = std::nullopt, \ py::arg("scalar") = std::nullopt ); \ m.def("awq_gemm_asm_tuning", &awq_gemm_asm_tuning, \ py::arg("out"), \ py::arg("mat1"), \ py::arg("mat2"), \ py::arg("zero") = std::nullopt, \ py::arg("scalar") = std::nullopt, \ py::arg("solidx") = 0, \ py::arg("jsonfile") = std::nullopt ); \ #define AWQ_DQ_ASM_PYBIND \ m.def("awq_dq_asm", &awq_dq_asm, \ py::arg("out"), \ py::arg("mat1"), \ py::arg("zero") = std::nullopt, \ py::arg("scalar") = std::nullopt ); \ #define MOE_CK_PYBIND \ m.def("ck_moe", &ck_moe, \ py::arg("hidden_states"), py::arg("w1"), py::arg("w2"), \ py::arg("topk_weights"), py::arg("topk_ids"), \ py::arg("use_int8_w8a16") = false, \ py::arg("use_int4_w4a16") = false, \ py::arg("use_int8_w8a8_block") = false, \ py::arg("use_int4_w4a8_block") = false, \ py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt, \ py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("solution_id") = 0, \ py::arg("expert_mask") = std::nullopt); \ m.def("ck_shuffle_moe", &ck_shuffle_moe, \ py::arg("hidden_states"), py::arg("w1"), py::arg("w2"), \ py::arg("topk_weights"), py::arg("topk_ids"), \ py::arg("use_int8_w8a16") = false, \ py::arg("use_int4_w4a16") = false, \ py::arg("use_int8_w8a8_block") = false, \ py::arg("use_int4_w4a8_block") = false, \ py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt, \ py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("solution_id") = 0, \ py::arg("expert_mask") = std::nullopt); \ m.def("ck_moe_get_solutions", &ck_moe_get_solutions, \ py::arg("hidden_states"), py::arg("w1"), py::arg("w2"), \ py::arg("topk_weights"), py::arg("topk_ids"), \ py::arg("use_int8_w8a16") = false, \ py::arg("use_int4_w4a16") = false, \ py::arg("use_int8_w8a8_block") = false, \ py::arg("use_int4_w4a8_block") = false, \ py::arg("w1_zp") = std::nullopt, py::arg("w2_zp") = std::nullopt, \ py::arg("w1_scale") = std::nullopt, py::arg("w2_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, py::arg("a2_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("expert_mask") = std::nullopt); \ m.def("ck_moe_stage_1", &ck_moe_stage_1, \ py::arg("hidden_states"), \ py::arg("w1"), \ py::arg("w2"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_expert_ids"), \ py::arg("tokens_positions_per_expert"), \ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ py::arg("use_int8_w8a8_block") = false, \ py::arg("use_fp8_w8a8_block") = false, \ py::arg("w1_scale") = std::nullopt, \ py::arg("a1_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt, \ py::arg("act_op") = 0); \ m.def("ck_moe_stage_2", &ck_moe_stage_2, \ py::arg("inter_states"), \ py::arg("w1"), \ py::arg("w2"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_expert_ids"), \ py::arg("tokens_positions_per_expert"), \ py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("topk"), \ py::arg("use_int8_w8a8_block") = false, \ py::arg("use_fp8_w8a8_block") = false, \ py::arg("w2_scale") = std::nullopt, \ py::arg("a2_scale") = std::nullopt, \ py::arg("block_shape_n") = 0, \ py::arg("block_shape_k") = 0, \ py::arg("block_m") = 32, \ py::arg("sorted_weights") = std::nullopt); \ m.def("ck_moe_per_token_quant", &ck_moe_per_token_quant, \ py::arg("input"), \ py::arg("out_quant"), \ py::arg("out_scale")); \ #define MOE_UTILS_PYBIND \ m.def("topk_softmax", \ &aiter::topk_softmax, \ py::arg("topk_weights"), \ py::arg("topk_indices"), \ py::arg("token_expert_indices"), \ py::arg("gating_output"), \ py::arg("need_renorm"), \ "Apply topk softmax to the gating outputs."); \ m.def("grouped_topk", \ &grouped_topk, \ py::arg("gating_output"), \ py::arg("topk_weights"), \ py::arg("topk_ids"), \ py::arg("num_expert_group"), \ py::arg("topk_grp"), \ py::arg("need_renorm"), \ py::arg("is_softmax") = true, \ py::arg("routed_scaling_factor") = 1.0f, \ "Apply grouped topk softmax/sigmodd to the gating outputs."); \ m.def("biased_grouped_topk", \ &biased_grouped_topk, \ py::arg("gating_output"), \ py::arg("correction_bias"), \ py::arg("topk_weights"), \ py::arg("topk_ids"), \ py::arg("num_expert_group"), \ py::arg("topk_grp"), \ py::arg("need_renorm"), \ py::arg("routed_scaling_factor") = 1.0f, \ "Apply biased grouped topk softmax to the gating outputs."); \ m.def("moe_sum", &aiter::moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); \ m.def("moe_fused_gate", \ &moe_fused_gate, \ py::arg("input"), \ py::arg("bias"), \ py::arg("topk_weights"), \ py::arg("topk_ids"), \ py::arg("num_expert_group"), \ py::arg("topk_group"), \ py::arg("topk"), \ py::arg("num_fused_shared_experts"), \ py::arg("routed_scaling_factor") = 1.0, \ "Apply biased grouped topk softmax to the gating outputs."); \ m.def("moe_align_block_size", &moe_align_block_size, \ "moe_align_block_size(Tensor topk_ids, int num_experts," \ " int block_size, Tensor! sorted_token_ids," \ " Tensor! experts_ids," \ " Tensor! num_tokens_post_pad) -> ()"); \ m.def("sgl_moe_align_block_size", &sgl_moe_align_block_size, \ "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," \ " int block_size, Tensor! sorted_token_ids," \ " Tensor! experts_ids," \ " Tensor! num_tokens_post_pad) -> ()"); \ #define MOE_OP_PYBIND \ m.def("fmoe", &fmoe); \ m.def("fmoe_int8_g1u0", &fmoe_int8_g1u0, \ py::arg("out"), py::arg("input"), \ py::arg("gate"), py::arg("down"), \ py::arg("sorted_token_ids"), py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ py::arg("topk"), py::arg("input_scale"), \ py::arg("fc1_scale"), py::arg("fc2_scale"), \ py::arg("fc2_smooth_scale") = std::nullopt, \ py::arg("activation") = ActivationType::Silu); \ m.def("fmoe_g1u1", &fmoe_g1u1, \ py::arg("out"), py::arg("input"), \ py::arg("gate"), py::arg("down"), \ py::arg("sorted_token_ids"), py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ py::arg("topk"), py::arg("input_scale"), \ py::arg("fc1_scale"), py::arg("fc2_scale"), \ py::arg("fc2_smooth_scale") = std::nullopt, \ py::arg("activation") = ActivationType::Silu); \ m.def("fmoe_g1u1_tkw1", &fmoe_g1u1_tkw1, \ py::arg("out"), py::arg("input"), \ py::arg("gate"), py::arg("down"), \ py::arg("sorted_token_ids"), py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ py::arg("topk"), py::arg("input_scale"), \ py::arg("fc1_scale"), py::arg("fc2_scale"), \ py::arg("fc2_smooth_scale") = std::nullopt, \ py::arg("activation") = ActivationType::Silu); \ m.def("fmoe_int8_g1u0_a16", &fmoe_int8_g1u0_a16); \ m.def("fmoe_g1u1_a16", &fmoe_g1u1_a16); \ m.def("fmoe_fp8_blockscale_g1u1", &fmoe_fp8_blockscale_g1u1, \ py::arg("out"), py::arg("input"), \ py::arg("gate"), py::arg("down"), \ py::arg("sorted_token_ids"), py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ py::arg("topk"), \ py::arg("input_scale"), \ py::arg("fc1_scale"), py::arg("fc2_scale"), \ py::arg("fc_scale_blkn") = 128, py::arg("fc_scale_blkk") = 128, \ py::arg("fc2_smooth_scale") = std::nullopt, \ py::arg("activation") = ActivationType::Silu); \ m.def("moe_stage1_g1u1", &moe_stage1_g1u1, \ py::arg("input"), \ py::arg("w1"), py::arg("w2"), \ py::arg("sorted_token_ids"), \ py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ py::arg("out"), \ py::arg("inter_dim"), \ py::arg("kernelName"), \ py::arg("block_m"), \ py::arg("ksplit") = 0, \ py::arg("activation") = ActivationType::Silu, \ py::arg("quant_type") = QuantType::No, \ py::arg("a1_scale") = std::nullopt, \ py::arg("w1_scale") = std::nullopt, \ py::arg("sorted_weights") = std::nullopt); \ #define MOE_SUM_PYBIND \ m.def("asm_moe_sum", &asm_moe_sum, "asm_moe_sum(Tensor! input, Tensor output, Tensor sorted_ids) -> ()"); \ #define MOE_SORTING_PYBIND \ m.def("moe_sorting_fwd", &moe_sorting_fwd, \ py::arg("topk_ids"), py::arg("topk_weights"), \ py::arg("sorted_token_ids"), py::arg("sorted_weights"), \ py::arg("sorted_expert_ids"), py::arg("tokens_positions_per_expert"), \ py::arg("num_valid_ids"), py::arg("moe_buf"), py::arg("num_experts"), \ py::arg("unit_size"), py::arg("local_expert_mask") = std::nullopt); #define NORM_PYBIND \ m.def("layernorm2d_fwd", &layernorm2d, \ py::arg("input"), py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_add", &layernorm2d_with_add, \ py::arg("out"), py::arg("input"), \ py::arg("residual_in"), py::arg("residual_out"), \ py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_smoothquant", &layernorm2d_with_smoothquant, \ py::arg("out"), py::arg("input"), \ py::arg("xscale"), py::arg("yscale"), \ py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_add_smoothquant", &layernorm2d_with_add_smoothquant, \ py::arg("out"), py::arg("input"), \ py::arg("residual_in"), py::arg("residual_out"), \ py::arg("xscale"), py::arg("yscale"), \ py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_dynamicquant", &layernorm2d_with_dynamicquant, \ py::arg("out"), py::arg("input"), \ py::arg("yscale"), py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); \ m.def("layernorm2d_fwd_with_add_dynamicquant", &layernorm2d_with_add_dynamicquant, \ py::arg("out"), py::arg("input"), \ py::arg("residual_in"), py::arg("residual_out"), \ py::arg("yscale"), py::arg("weight"), py::arg("bias"), \ py::arg("epsilon"), py::arg("x_bias") = std::nullopt); // m.def("layernorm2d_with_add_asm", &layernorm2d_with_add_asm); \ // m.def("layernorm2d_with_add_smoothquant_asm", &layernorm2d_with_add_smoothquant_asm); #define POS_ENCODING_PYBIND \ m.def("rotary_embedding_fwd", &rotary_embedding, "rotary_embedding"); \ m.def("batched_rotary_embedding", &batched_rotary_embedding, "batched_rotary_embedding"); #define QUANT_PYBIND \ m.def("static_per_tensor_quant", &aiter::static_per_tensor_quant); \ m.def("dynamic_per_tensor_quant", &aiter::dynamic_per_tensor_quant); \ m.def("dynamic_per_token_scaled_quant", \ &aiter::dynamic_per_token_scaled_quant, \ py::arg("out"), \ py::arg("input"), \ py::arg("scales"), \ py::arg("scale_ub") = std::nullopt, \ py::arg("shuffle_scale") = false, \ py::arg("num_rows") = std::nullopt, \ py::arg("num_rows_factor") = 1); \ m.def("dynamic_per_group_scaled_quant_fp4", \ &aiter::dynamic_per_group_scaled_quant_fp4, \ py::arg("out"), \ py::arg("input"), \ py::arg("scales"), \ py::arg("group_size") = 32, \ py::arg("shuffle_scale") = true, \ py::arg("num_rows") = std::nullopt, \ py::arg("num_rows_factor") = 1); \ m.def("smooth_per_token_scaled_quant", \ &aiter::smooth_per_token_scaled_quant, \ py::arg("out"), \ py::arg("input"), \ py::arg("scales"), \ py::arg("smooth_scale"), \ py::arg("smooth_scale_map") = std::nullopt, \ py::arg("shuffle_scale") = false, \ py::arg("num_rows") = std::nullopt, \ py::arg("num_rows_factor") = 1); \ m.def("partial_transpose", \ &aiter::partial_transpose, \ py::arg("out"), \ py::arg("input"), \ py::arg("num_rows")); \ m.def("moe_swiglu_dynamic_quant", \ &aiter::moe_swiglu_dynamic_quant, \ py::arg("scatter_tokens"), \ py::arg("smooth"), \ py::arg("experts_tokens_count"), \ py::arg("experts_tokens_start"), \ py::arg("output"), \ py::arg("scales"), \ py::arg("beta") = 1.0f); #define RMSNORM_PYBIND \ m.def("rms_norm_cu", \ &rms_norm, \ "Apply Root Mean Square (RMS) Normalization to the input tensor."); \ m.def( \ "fused_add_rms_norm_cu", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); \ m.def("rmsnorm2d_fwd", \ &rmsnorm2d, \ py::arg("input"), \ py::arg("weight"), \ py::arg("epsilon")); \ m.def("rmsnorm2d_fwd_with_add", \ &rmsnorm2d_with_add, \ py::arg("out"), \ py::arg("input"), \ py::arg("residual_in"), \ py::arg("residual_out"), \ py::arg("weight"), \ py::arg("epsilon")); \ m.def("rmsnorm2d_fwd_with_smoothquant", \ &rmsnorm2d_with_smoothquant, \ py::arg("out"), \ py::arg("input"), \ py::arg("xscale"), \ py::arg("yscale"), \ py::arg("weight"), \ py::arg("epsilon")); \ m.def("rmsnorm2d_fwd_with_add_smoothquant", \ &rmsnorm2d_with_add_smoothquant, \ py::arg("out"), \ py::arg("input"), \ py::arg("residual_in"), \ py::arg("residual_out"), \ py::arg("xscale"), \ py::arg("yscale"), \ py::arg("weight"), \ py::arg("epsilon"), \ py::arg("out_before_quant") = std::nullopt); \ m.def("rmsnorm2d_fwd_with_dynamicquant", \ &rmsnorm2d_with_dynamicquant, \ py::arg("out"), \ py::arg("input"), \ py::arg("yscale"), \ py::arg("weight"), \ py::arg("epsilon")); \ m.def("rmsnorm2d_fwd_with_add_dynamicquant", \ &rmsnorm2d_with_add_dynamicquant, \ py::arg("out"), \ py::arg("input"), \ py::arg("residual_in"), \ py::arg("residual_out"), \ py::arg("yscale"), \ py::arg("weight"), \ py::arg("epsilon")); \ m.def("head_rms_norm", \ &head_rms_norm, \ py::arg("input"), \ py::arg("weight"), \ py::arg("epsilon"), \ py::arg("norm_head_dim")); #define ROPE_GENERAL_FWD_PYBIND \ m.def("rope_fwd_impl", &rope_fwd_impl); \ m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); \ m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); \ m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); \ m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); \ m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl); #define ROPE_GENERAL_BWD_PYBIND \ m.def("rope_bwd_impl", &rope_bwd_impl); \ m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); \ m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); \ m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); \ m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); \ m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl); #define ROPE_POS_FWD_PYBIND \ m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl); \ m.def("rope_cached_positions_2c_fwd_impl", &rope_cached_positions_2c_fwd_impl); \ m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl); \ m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl); #define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND \ m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle", \ &fused_qk_norm_mrope_3d_cache_pts_quant_shuffle, \ py::arg("qkv"), \ py::arg("qw"), \ py::arg("kw"), \ py::arg("cos_sin"), \ py::arg("positions"), \ py::arg("num_tokens"), \ py::arg("num_heads_q"), \ py::arg("num_heads_k"), \ py::arg("num_heads_v"), \ py::arg("head_size"), \ py::arg("is_neox_style"), \ py::arg("mrope_section_"), \ py::arg("is_interleaved"), \ py::arg("eps"), \ py::arg("q_out"), \ py::arg("k_cache"), \ py::arg("v_cache"), \ py::arg("slot_mapping"), \ py::arg("per_tensor_k_scale"), \ py::arg("per_tensor_v_scale"), \ py::arg("k_out"), \ py::arg("v_out"), \ py::arg("return_kv"), \ py::arg("use_shuffle_layout"), \ py::arg("block_size"), \ py::arg("x"), \ py::arg("rotary_dim") = 0); #define FUSED_QKNORM_ROPE_CACHE_QUANT_PYBIND \ m.def("fused_qk_norm_rope_cache_quant_shuffle", \ &aiter::fused_qk_norm_rope_cache_quant_shuffle); \ m.def("fused_qk_norm_rope_cache_pts_quant_shuffle", \ &aiter::fused_qk_norm_rope_cache_pts_quant_shuffle, \ py::arg("qkv"), \ py::arg("qw"), \ py::arg("kw"), \ py::arg("cos_sin"), \ py::arg("positions"), \ py::arg("num_tokens"), \ py::arg("num_heads_q"), \ py::arg("num_heads_k"), \ py::arg("num_heads_v"), \ py::arg("head_size"), \ py::arg("is_neox_style"), \ py::arg("eps"), \ py::arg("q_out"), \ py::arg("k_cache"), \ py::arg("v_cache"), \ py::arg("slot_mapping"), \ py::arg("per_tensor_k_scale"), \ py::arg("per_tensor_v_scale"), \ py::arg("k_out"), \ py::arg("v_out"), \ py::arg("return_kv"), \ py::arg("use_shuffle_layout"), \ py::arg("block_size"), \ py::arg("x"), \ py::arg("rotary_dim") = 0); \ m.def("fused_qk_norm_rope_cache_block_quant_shuffle", \ &aiter::fused_qk_norm_rope_cache_block_quant_shuffle, \ py::arg("qkv"), \ py::arg("num_heads_q"), \ py::arg("num_heads_k"), \ py::arg("num_heads_v"), \ py::arg("head_dim"), \ py::arg("eps"), \ py::arg("q_weight"), \ py::arg("k_weight"), \ py::arg("cos_sin_cache"), \ py::arg("is_neox"), \ py::arg("position_ids"), \ py::arg("k_cache"), \ py::arg("v_cache"), \ py::arg("slot_mapping"), \ py::arg("cu_q_len"), \ py::arg("kv_cache_dtype"), \ py::arg("k_scale"), \ py::arg("v_scale"), \ py::arg("max_tokens_per_batch") = 0); \ m.def("fused_qk_norm_rope_2way", &aiter::fused_qk_norm_rope_2way); #define SMOOTHQUANT_PYBIND \ m.def("smoothquant_fwd", &smoothquant_fwd); \ m.def("moe_smoothquant_fwd", &moe_smoothquant_fwd); #define HIPBSOLGEMM_PYBIND \ m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); \ m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); \ m.def("hipb_mm", &hipb_mm, "hipb_mm", py::arg("mat1"), py::arg("mat2"), \ py::arg("solution_index"), py::arg("bias") = std::nullopt, \ py::arg("out_dtype") = std::nullopt, py::arg("scaleA") = std::nullopt, \ py::arg("scaleB") = std::nullopt, py::arg("scaleOut") = std::nullopt, \ py::arg("scaleType") = std::nullopt); \ m.def("hipb_findallsols", &hipb_findallsols, "hipb_findallsols", \ py::arg("mat1"), py::arg("mat2"), py::arg("bias") = std::nullopt, \ py::arg("out_dtype") = std::nullopt, py::arg("scaleA") = std::nullopt, \ py::arg("scaleB") = std::nullopt, py::arg("scaleC") = std::nullopt, \ py::arg("scaleType") = std::nullopt); \ m.def("getHipblasltKernelName", &getHipblasltKernelName); #define ROCSOLGEMM_PYBIND \ m.def("rocb_create_extension", &rocb_create_extension, "create_extension"); \ m.def("rocb_destroy_extension", &rocb_destroy_extension, "destroy_extension"); \ m.def("rocb_mm", &RocSolIdxBlas, "mm"); \ m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols"); #define AITER_ENUM_PYBIND \ pybind11::enum_(m, "QuantType") \ .value("No", QuantType::No) \ .value("per_Tensor", QuantType::per_Tensor) \ .value("per_Token", QuantType::per_Token) \ .value("per_1x32", QuantType::per_1x32) \ .value("per_1x128", QuantType::per_1x128) \ .value("per_128x128", QuantType::per_128x128) \ .export_values(); \ pybind11::enum_(m, "ActivationType") \ .value("No", ActivationType::No) \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ .export_values(); \ pybind11::implicitly_convertible(); \ pybind11::implicitly_convertible(); #define TOPK_PLAIN_PYBIND \ m.def("topk_plain", \ &topk_plain, \ py::arg("values"), \ py::arg("topk_ids"), \ py::arg("topk_out"), \ py::arg("topk"), \ py::arg("largest") = true, \ py::arg("rowStarts") = torch::Tensor(), \ py::arg("rowEnds") = torch::Tensor(), \ py::arg("stride0") = -1, \ py::arg("stride1") = 1); #define TOPK_TRANSFORM_PYBIND \ m.def("fast_topk_interface", \ &fast_topk_interface, \ py::arg("score"), \ py::arg("indices"), \ py::arg("lengths"), \ py::arg("row_starts_opt") = std::nullopt); \ m.def("fast_topk_transform_interface", \ &fast_topk_transform_interface, \ py::arg("score"), \ py::arg("lengths"), \ py::arg("dst_page_table"), \ py::arg("src_page_table"), \ py::arg("cu_seqlens_q"), \ py::arg("row_starts_opt") = std::nullopt); \ m.def("fast_topk_transform_ragged_interface", \ &fast_topk_transform_ragged_interface, \ py::arg("score"), \ py::arg("lengths"), \ py::arg("topk_indices_ragged"), \ py::arg("topk_indices_offset"), \ py::arg("row_starts_opt") = std::nullopt); #define MOE_C_PYBIND \ m.def("moe_c_moe_gemm_marlin_w8a8", \ &moe_c_moe_gemm_marlin_w8a8, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("a_scale"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta"), \ py::arg("size_m") \ ); \ m.def("moe_c_moe_gemm_marlin_w8a8_tensorwise", \ &moe_c_moe_gemm_marlin_w8a8_tensorwise, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("a_scale"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta"), \ py::arg("size_m") \ ); \ m.def("moe_c_moe_gemm_marlin_w4a8", \ &moe_c_moe_gemm_marlin_w4a8, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("a_scale"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta"), \ py::arg("size_m") \ ); \ m.def("moe_c_moe_gemm_marlin_w8a8_fp8", \ &moe_c_moe_gemm_marlin_w8a8_fp8, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("a_scale"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta"), \ py::arg("size_m") \ ); \ m.def("moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise", \ &moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("a_scale"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta"), \ py::arg("size_m") \ ); \ m.def("moe_c_moe_gemm_marlin_w4a16", \ &moe_c_moe_gemm_marlin_w4a16, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("b_scale"), \ py::arg("b_zeros"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta") \ ); \ m.def("moe_c_moe_gemm_marlin_w8a16", \ &moe_c_moe_gemm_marlin_w8a16, \ py::arg("input"), \ py::arg("b_qweight"), \ py::arg("output"), \ py::arg("b_scale"), \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("mode"), \ py::arg("delta") \ ); \ m.def("moe_c_moe_w8a16_gemm_block_wise", \ &moe_c_moe_w8a16_gemm_block_wise, \ py::arg("input"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("group_size_n"), \ py::arg("group_size_k"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("bit") \ ); \ \ /* ==================== moe_w8a16_gemm_awq ==================== */ \ m.def("moe_c_moe_w8a16_gemm_awq", \ &moe_c_moe_w8a16_gemm_awq, \ py::arg("input"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("bit") \ ); \ \ /* ==================== moe_wna16_gemm ==================== */ \ m.def("moe_c_moe_wna16_gemm", \ &moe_c_moe_wna16_gemm, \ py::arg("input"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ \ /* ==================== moe_wna16_gemm_2 ==================== */ \ m.def("moe_c_moe_wna16_gemm_2", \ &moe_c_moe_wna16_gemm_2, \ py::arg("input"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights"), \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ \ \ /* ==================== moe_align_block_size ==================== */ \ m.def("moe_c_moe_align_block_size", \ &moe_c_moe_align_block_size, \ py::arg("topk_ids"), \ py::arg("num_experts"), \ py::arg("block_size"), \ py::arg("sorted_token_ids"), \ py::arg("experts_ids"), \ py::arg("num_tokens_post_pad") \ ); \ \ \ /* ==================== moe_wna16_gemm_base ==================== */ \ m.def("moe_c_moe_wna16_gemm_base", \ &moe_c_moe_wna16_gemm_base, \ py::arg("input"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros"), \ py::arg("topk_weights"), \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("bit") \ ); \ \ /* ==================== sgl_moe_align_block_size ==================== */ \ m.def("moe_c_sgl_moe_align_block_size", \ &moe_c_sgl_moe_align_block_size, \ py::arg("topk_ids"), \ py::arg("num_experts"), \ py::arg("block_size"), \ py::arg("sorted_token_ids"), \ py::arg("experts_ids"), \ py::arg("num_tokens_post_pad") \ ); \ \ \ \ /* ==================== moe_w8a8_gemm_block_wise ==================== */ \ m.def("moe_c_moe_w8a8_gemm_block_wise", \ &moe_c_moe_w8a8_gemm_block_wise, \ py::arg("input"), \ py::arg("a_scales"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights"), \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("group_size_n"), \ py::arg("group_size_k"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ \ /* ==================== moe_w8a8_gemm_block_wise_kernel2 ==================== */ \ m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2", \ &moe_c_moe_w8a8_gemm_block_wise_kernel2, \ py::arg("input"), \ py::arg("a_scales"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights") , \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("group_size_n"), \ py::arg("group_size_k"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ \ /* ==================== moe_w8a8_gemm_block_wise_fp8 ==================== */ \ m.def("moe_c_moe_w8a8_gemm_block_wise_fp8", \ &moe_c_moe_w8a8_gemm_block_wise_fp8, \ py::arg("input"), \ py::arg("a_scales"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights"), \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("group_size_n"), \ py::arg("group_size_k"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ \ /* ==================== moe_w8a8_gemm_block_wise_kernel2_fp8 ==================== */ \ m.def("moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8", \ &moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8, \ py::arg("input"), \ py::arg("a_scales"), \ py::arg("output"), \ py::arg("b_qweight"), \ py::arg("b_scales"), \ py::arg("b_qzeros") , \ py::arg("topk_weights"), \ py::arg("sorted_token_ids"), \ py::arg("expert_ids"), \ py::arg("num_tokens_post_pad"), \ py::arg("group_size_n"), \ py::arg("group_size_k"), \ py::arg("top_k"), \ py::arg("BLOCK_SIZE_M"), \ py::arg("BLOCK_SIZE_N"), \ py::arg("BLOCK_SIZE_K"), \ py::arg("kloops"), \ py::arg("nloops"), \ py::arg("bit") \ ); \ m.def("moe_c_topk_softmax", \ &moe_c_topk_softmax, \ py::arg("topk_weights"), \ py::arg("topk_indices"), \ py::arg("token_expert_indices"), \ py::arg("gating_output") \ ); \ /* ==================== silu_and_mul ==================== */ \ m.def("moe_c_silu_and_mul",\ &moe_c_silu_and_mul,\ py::arg("out"),\ py::arg("input"),\ py::arg("rows_per_block") = 1,\ py::arg("vec_size") = 2);\ m.def("moe_c_moe_sum_opt_v2",\ &moe_c_moe_sum_opt_v2,\ py::arg("input"),\ py::arg("output"),\ py::arg("routed_scaling_factor") = 1.0); #define MHC_PYBIND \ m.def("mhc_pre_gemm_sqrsum", \ &aiter::mhc_pre_gemm_sqrsum, \ "mhc_pre_gemm_sqrsum", \ py::arg("out"), \ py::arg("sqrsum"), \ py::arg("x"), \ py::arg("fn"), \ py::arg("tile_k") = 128, \ py::arg("use_tf32") = false); \ m.def("mhc_pre_gemm_sqrsum_stage1_m128", \ &aiter::mhc_pre_gemm_sqrsum_stage1_m128, \ "mhc_pre_gemm_sqrsum_stage1_m128", \ py::arg("out"), \ py::arg("sqrsum"), \ py::arg("x"), \ py::arg("fn"), \ py::arg("use_tf32") = false); \ m.def("mhc_pre_reduce_splitk", \ &aiter::mhc_pre_reduce_splitk, \ "mhc_pre_reduce_splitk", \ py::arg("out_red"), \ py::arg("sqrsum_red"), \ py::arg("out"), \ py::arg("sqrsum")); \ m.def("mhc_pre_big_fuse", \ &aiter::mhc_pre_big_fuse, \ "mhc_pre_big_fuse", \ py::arg("post_mix"), \ py::arg("comb_mix"), \ py::arg("layer_input"), \ py::arg("gemm_out_mul"), \ py::arg("gemm_out_sqrsum"), \ py::arg("hc_scale"), \ py::arg("hc_base"), \ py::arg("residual"), \ py::arg("rms_eps") = 1e-6, \ py::arg("hc_pre_eps") = 1e-6, \ py::arg("hc_sinkhorn_eps") = 1e-6, \ py::arg("hc_post_mult_value") = 1.0, \ py::arg("sinkhorn_repeat") = 20); \ m.def("mhc_pre_big_fuse_tlstyle", \ &aiter::mhc_pre_big_fuse_tlstyle, \ "mhc_pre_big_fuse_tlstyle", \ py::arg("post_mix"), \ py::arg("comb_mix"), \ py::arg("layer_input"), \ py::arg("gemm_out_mul"), \ py::arg("gemm_out_sqrsum"), \ py::arg("hc_scale"), \ py::arg("hc_base"), \ py::arg("residual"), \ py::arg("rms_eps") = 1e-6, \ py::arg("hc_pre_eps") = 1e-6, \ py::arg("hc_sinkhorn_eps") = 1e-6, \ py::arg("hc_post_mult_value") = 1.0, \ py::arg("sinkhorn_repeat") = 20); \ m.def("mhc_post", \ &aiter::mhc_post, \ "mhc_post", \ py::arg("out"), \ py::arg("x"), \ py::arg("residual"), \ py::arg("post_layer_mix"), \ py::arg("comb_res_mix"));