#pragma once // SPDX-License-Identifier: MIT #include #include "aiter_enum.h" void asm_fmoe_stage1(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // uint32_t top_k, std::optional scale_a = std::nullopt, std::optional scale_b = std::nullopt, std::optional zero_points = std::nullopt, std::optional mode = 0, std::optional solidx = 0, std::optional block_size = 16, std::optional persist_groups = 0 ); void asm_fmoe_stage2(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // uint32_t top_k, std::optional scale_a = std::nullopt, std::optional scale_b = std::nullopt, std::optional zero_points = std::nullopt, std::optional mode = 0, std::optional solidx = 0, std::optional block_size = 16, std::optional persist_groups = 0 ); void asm_fmoe_a8(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, hidden_dim, dim] N,K torch::Tensor &down, // [expert, hidden_dim, dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // uint32_t top_k, std::optional scale_a = std::nullopt, std::optional scale_b = std::nullopt, std::optional zero_points = std::nullopt, std::optional mode = 0, std::optional solidx = 0, std::optional out_type = 0, std::optional persist_groups = 0, std::optional use_shuffle = 0 ); std::vector asm_moe_get_solutions(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8, // use int8 w8a8 quantization std::optional use_int4_w4a8, // use int4 w4a8 quantization std::optional use_fp8_w8a8, // use f8 w8a8 quantization std::optional per_channel_quant, // use channel quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional expert_mask = std::nullopt);