// SPDX-License-Identifier: MIT #include #include #include #include #include #include "aiter_hip_common.h" #include "moe_op.h" //#include "py_itfs_common.h" #include #include #include class Moe1ModeLookup { public: Moe1ModeLookup() { ModeMap = { {1, 10004}, {2, 10000}, {4, 10007}, {8, 10000}, {12, 10000},{16, 10000}, {24, 10000}, {32, 10000}, {48, 10000}, {64, 10000},{96, 10000}, {128, 10007}, {160, 10000}, {256, 10000}, {320, 10000}, {512, 10000}, {768, 10000}, {1024, 10008}, {1536, 10008}, {2048, 10008}, {2560, 10008}, {2816, 11006}, {3072, 11006}, {4096, 11006}, {4608, 11006}, {4864, 11006}, {5120, 11006}, {5504, 11006}, {5760, 11006}, {6144, 12003}, {7168, 12003}, {8192, 12003}, {12288, 12003}, {16384, 12003}, {24576, 12003}, {32768, 12003} }; } uint32_t getMode(int m) { if (m < 1) throw std::out_of_range("m must be >= 1"); if (m > 32768) m = 32768; auto it = ModeMap.lower_bound(m); // Find the first key >= m // if it > m or end(), return the previous one if (it == ModeMap.end() || it->first > m) { if (it == ModeMap.begin()) throw std::logic_error("No valid MODE found"); --it; } return it->second; } private: std::map ModeMap; }; class Moe2ModeLookup { public: Moe2ModeLookup() { ModeMap = { {1, 20006}, {2, 20006}, {4, 20005}, {8, 20005}, {12, 20005},{16, 20005}, {24, 20005}, {32, 20006}, {48, 20005}, {64, 20005},{96, 20005}, {128, 20005}, {160, 20005}, {256, 20005}, {320, 20006}, {512, 20005}, {768, 20006}, {1024, 20006}, {1536, 20006}, {2048, 20006}, {2560, 20006}, {2816, 21004}, {3072, 21004}, {4096, 21004}, {4608, 21004}, {4864, 21004}, {5120, 21004}, {5504, 21004}, {5760, 21004}, {6144, 22002}, {7168, 22002}, {8192, 22002}, {12288, 22002}, {16384, 22002}, {24576, 22002}, {32768, 22002} }; ModeMap128k = { {1, 20101}, {2, 20101}, {4, 20006}, {8, 20006}, {12, 20006},{16, 20006}, {24, 20006}, {32, 20006}, {48, 20006}, {64, 20006},{96, 20006}, {128, 20100}, {160, 20006}, {256, 20006}, {320, 20006}, {512, 20006}, {768, 20006}, {1024, 20006}, {1536, 20006}, {2048, 20006}, {2560, 20006}, {2816, 21100}, {3072, 21100}, {4096, 21100}, {4608, 21101}, {4864, 21101}, {5120, 21100}, {5504, 21100}, {5760, 21100}, {6144, 22100}, {7168, 22100}, {8192, 22100}, {12288, 22100}, {16384, 22100}, {24576, 22100}, {32768, 22100} }; } uint32_t getMode(int m, int k=0) { if (m < 1) throw std::out_of_range("m must be >= 1"); if (m > 32768) m = 32768; // if it > m or end(), return the previous one if (k == 128) { auto it = ModeMap128k.lower_bound(m); // Find the first key >= m if (it == ModeMap128k.end() || it->first > m) { if (it == ModeMap128k.begin()) throw std::logic_error("No valid MODE found"); --it; } return it->second; } else { auto it = ModeMap.lower_bound(m); // Find the first key >= m if (it == ModeMap.end() || it->first > m) { if (it == ModeMap.begin()) throw std::logic_error("No valid MODE found"); --it; } return it->second; } } private: std::map ModeMap; std::map ModeMap128k; }; #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) struct __attribute__((packed)) KernelArgs { uint32_t gemm_count; uint32_t internalArgs; uint32_t internalArgs1; uint32_t numWorkGroups; void* argsPtr; }; struct __attribute__((packed)) GroupedGemmArgs { uint32_t m; //!< size m uint32_t n; //!< size n uint32_t batch; //!< size batch uint32_t k; //!< size k void* d; //!< The d matrix input pointer. void* c; //!< The c matrix input pointer. void* a; //!< The a matrix input pointer. void* b; //!< The b matrix input pointer. uint32_t strideD1; //!< The d leading dimension. uint32_t strideD2; //!< The d batch stride uint32_t strideC1; //!< The c leading dimension. uint32_t strideC2; //!< The c batch stride uint32_t strideA1; //!< The a leading dimension. uint32_t strideA2; //!< The a batch stride uint32_t strideB1; //!< The b leading dimension. uint32_t strideB2; //!< The b batch stride float alpha; //!< The alpha value. float beta; //!< The beta value. void* sorted_token_ids; void* sorted_weights; void* sorted_expert_ids; uint32_t num_valid_tokens; uint32_t top_k; void* scale_a; void* scale_b; void* zero_points; void* num_valid_ids; }; struct __attribute__((packed)) HipFunctionArgs { uint32_t align_size; uint32_t internalArgs; uint32_t internalArgs1; uint32_t numWorkGroups; uint32_t m; //!< size m uint32_t n; //!< size n uint32_t k; //!< size k void* d; //!< The d matrix input pointer. void* a; //!< The a matrix input pointer. void* b; //!< The b matrix input pointer. float alpha; //!< The alpha value. float beta; //!< The beta value. void* sorted_token_ids; void* sorted_weights; void* sorted_expert_ids; uint32_t num_valid_tokens; uint32_t experts_num; // num of experts uint32_t top_k; float topk_rcip; void* scale_a; void* scale_b; void* zero_points; void* num_valid_ids; uint32_t out_type; uint32_t persist_groups; uint32_t groups_nums; }; void printHipFunctionArgs(const HipFunctionArgs& args) { std::cout << "align_size: " << args.align_size << "\n"; std::cout << "internalArgs: " << args.internalArgs << "\n"; std::cout << "internalArgs1: " << args.internalArgs1 << "\n"; std::cout << "numWorkGroups0: " << args.numWorkGroups << "\n"; std::cout << "m: " << args.m << "\n"; std::cout << "n: " << args.n << "\n"; std::cout << "k: " << args.k << "\n"; std::cout << "d: " << static_cast(args.d) << "\n"; std::cout << "a: " << static_cast(args.a) << "\n"; std::cout << "b: " << static_cast(args.b) << "\n"; std::cout << "alpha: " << args.alpha << "\n"; std::cout << "beta: " << args.beta << "\n"; std::cout << "sorted_token_ids: " << static_cast(args.sorted_token_ids) << "\n"; std::cout << "sorted_weights: " << static_cast(args.sorted_weights) << "\n"; std::cout << "sorted_expert_ids: " << static_cast(args.sorted_expert_ids) << "\n"; std::cout << "num_valid_tokens: " << args.num_valid_tokens << "\n"; std::cout << "experts_num: " << args.experts_num << "\n"; std::cout << "top_k: " << args.top_k << "\n"; std::cout << "topk_rcip: " << args.topk_rcip << "\n"; std::cout << "scale_a: " << static_cast(args.scale_a) << "\n"; std::cout << "scale_b: " << static_cast(args.scale_b) << "\n"; std::cout << "zero_points: " << static_cast(args.zero_points) << "\n"; std::cout << "num_valid_ids: " << static_cast(args.num_valid_ids) << "\n"; std::cout << "out_type: " << args.out_type << "\n"; std::cout << "persist_groups: " << args.persist_groups << "\n"; std::cout << "groups_nums: " << args.groups_nums << "\n"; } class FMoeKernelA8 { private: hipModule_t module; hipFunction_t kernel_func; public: FMoeKernelA8(const char *name, const char *hsaco) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << std::endl; }; template void launch_kernel(const std::vector& Config, torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &w1, // [expert, inter_dim, dim] N,K torch::Tensor &w2, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor &num_valid_ids, // uint32_t top_k, int OutType, int PersistGroups, int BlockSizeM, std::optional scale_a = std::nullopt, std::optional scale_b = std::nullopt, std::optional zero_points = std::nullopt ) { int size_m, size_n, size_k; if constexpr (firstStage) { size_m = w1.size(1); } else { size_m = w2.size(1); } size_n = Config[1]; size_k = input.size(1); HipFunctionArgs hipFunctionArgs; hipFunctionArgs.align_size = BlockSizeM; hipFunctionArgs.internalArgs = 0x00200001; hipFunctionArgs.internalArgs1 = 1; uint32_t chosen_experts; if constexpr (firstStage) { chosen_experts = std::min(input.size(0) * top_k, sorted_token_ids.size(0) / Config[1]); } else { chosen_experts = std::min(input.size(0), sorted_token_ids.size(0) / Config[1]); } chosen_experts = chosen_experts * ((BlockSizeM - 1)/Config[1] + 1); hipFunctionArgs.numWorkGroups = DIVIDE(size_m, Config[0]); hipFunctionArgs.groups_nums = chosen_experts * DIVIDE(size_m, Config[0]); hipFunctionArgs.m = size_m; hipFunctionArgs.n = size_n; hipFunctionArgs.k = size_k; hipFunctionArgs.d = out.data_ptr(); if constexpr (firstStage) { hipFunctionArgs.a = w1.data_ptr(); } else { hipFunctionArgs.a = w2.data_ptr(); } hipFunctionArgs.b = input.data_ptr(); hipFunctionArgs.alpha = 1; hipFunctionArgs.beta = 0; hipFunctionArgs.sorted_token_ids = sorted_token_ids.data_ptr(); hipFunctionArgs.sorted_weights = sorted_weights.data_ptr(); hipFunctionArgs.sorted_expert_ids = sorted_expert_ids.data_ptr(); hipFunctionArgs.num_valid_tokens = input.size(0); hipFunctionArgs.experts_num = w1.size(0); hipFunctionArgs.top_k = top_k; hipFunctionArgs.topk_rcip = 1 / float(top_k); if(scale_a.has_value() && scale_a.value().has_storage()){ hipFunctionArgs.scale_b = scale_a.value().data_ptr(); } else{ hipFunctionArgs.scale_b = input.data_ptr(); } if(scale_b.has_value()&& scale_b.value().has_storage()){ hipFunctionArgs.scale_a = scale_b.value().data_ptr(); } else{ hipFunctionArgs.scale_a = input.data_ptr(); } if(zero_points.has_value()&& zero_points.value().has_storage()){ hipFunctionArgs.zero_points = zero_points.value().data_ptr(); } else{ hipFunctionArgs.zero_points = input.data_ptr(); } hipFunctionArgs.num_valid_ids = num_valid_ids.data_ptr(); hipFunctionArgs.out_type = OutType; hipFunctionArgs.persist_groups = PersistGroups; size_t arg_size = sizeof(hipFunctionArgs); void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &hipFunctionArgs, HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size, HIP_LAUNCH_PARAM_END}; int bdx = Config[2]; int gdx = hipFunctionArgs.groups_nums; if (PersistGroups != 0) { gdx = PersistGroups; } int gdy = 1; int gdz = 1; //printHipFunctionArgs(hipFunctionArgs); //std::cout << "bdx " << bdx << "\n"; //std::cout << "gdx " << gdx << "\n"; //std::cout << "gdy " << gdy << "\n"; //std::cout << "gdz " << gdz << "\n"; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, stream, nullptr, (void **)&config)); }; }; void asm_fmoe_a8(torch::Tensor &out, // [token_cnt, dim] torch::Tensor &input, // [token_cnt, dim] M,K torch::Tensor &gate, // [expert, inter_dim, dim] N,K torch::Tensor &down, // [expert, dim, inter_dim] torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] torch::Tensor &sorted_weights, // [max_num_tokens_padded] torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] torch::Tensor & num_valid_ids, // uint32_t top_k, std::optional scale_a = std::nullopt, std::optional scale_b = std::nullopt, std::optional zero_points = std::nullopt, std::optional mode = 0, std::optional solidx = 0, std::optional out_type = 0, std::optional persist_groups = 0, std::optional use_shuffle = 0 ) { struct FMoeKernelA8Config { std::string name; std::string co_name; uint32_t MT0; uint32_t MT1; uint32_t bdx; }; torch::Tensor ScaleA; torch::Tensor ScaleB; torch::Tensor ZeroPoints; int OutType = 0; int Mode = 0; int PersistGroups = 0; int Shuffle = 0; if (scale_a.has_value()) { ScaleA = scale_a.value(); } if (scale_b.has_value()) { ScaleB = scale_b.value(); } if (zero_points.has_value()) { ZeroPoints = zero_points.value(); } if (mode.has_value()) { Mode = mode.value(); } if (out_type.has_value()) { OutType = out_type.value(); } if (use_shuffle.has_value()) { Shuffle = use_shuffle.value(); } if (persist_groups.has_value()) { PersistGroups = persist_groups.value(); } FMoeKernelA8 *impl_ptr = nullptr; static uint32_t first_stage_topk = 1; static uint32_t BlockSizeM = 0; static std::unordered_map> impl_ptr_map; std::vector Config = {32, 32, 512}; bool first_stage = true; std::unordered_map *config_map = nullptr; if (Mode == 0 || Mode == 4) { first_stage =true; first_stage_topk = top_k; static std::unordered_map moe1_kernel_int8_configs = { {10000, {"MT32x16x256_SN_K1_PGR4_WG16_16_2_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x16x256_SN_K1_SB0_SWMNK2_2_1_moe1.co", 32, 16, 384}}, {10001, {"MT32x16x256_SN_K1_PGR5_WG16_16_2_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x16x256_SN_K1_SB3_SWMNK2_2_1_moe1.co", 32, 16, 384}}, {10002, {"MT64x16x256_SN_K1_PGR3_WG16_16_3_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT64x16x256_SN_K1_SB3_SWMNK4_2_1_moe1.co", 64, 16, 512}}, {10003, {"MT32x16x256_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x16x256_T_PGR4_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10004, {"MT32x16x256_SN_K1_PGR5_SB3_SWMNK2_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x16x256_T_PGR5_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10005, {"MT32x16x256_SN_K1_PGR6_SB3_SWMNK2_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x16x256_T_PGR6_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10006, {"MT64x16x256_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT64x16x256_T_PGR4_SB3_SWMNK4_1_1_moe1.co", 64, 16, 512}}, {10007, {"MT64x16x256_SN_K1_PGR5_SB3_SWMNK4_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT64x16x256_T_PGR5_SB3_SWMNK4_1_1_moe1.co", 64, 16, 512}}, {10008, {"MT128x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x16x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10009, {"MT128x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x16x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10010, {"MT128x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x16x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10011, {"MT256x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x16x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10012, {"MT256x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x16x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10013, {"MT256x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x16x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {11000, {"MT32x32x256_SN_K1_PGR4_WG16_16_2_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT32x32x256_SN_K1_SB0_SWMNK2_2_1_moe1.co", 32, 32, 512}}, {11001, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT64x32x256_SN_K1_SB0_SWMNK4_2_1_moe1.co", 64, 32, 768}}, {11002, {"MT128x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x32x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11003, {"MT256x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x32x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11004, {"MT128x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x32x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11005, {"MT256x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x32x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11006, {"MT128x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x32x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11007, {"MT256x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x32x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {12000, {"MT128x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x64x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12001, {"MT256x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x64x256_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12002, {"MT128x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x64x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12003, {"MT256x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x64x256_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12004, {"MT128x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x64x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12005, {"MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x64x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {13000, {"MT128x128x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x128x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 128, 768}}, {13001, {"MT256x128x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w8a8/stage1/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x128x256_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 128, 768}}}; config_map = &moe1_kernel_int8_configs; } else if (Mode == 1 || Mode == 5) { first_stage =false; static std::unordered_map moe2_kernel_int8_configs = { {20000, {"MT128x16x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x16x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 128, 16, 256}}, {20001, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 16, 256}}, {21000, {"MT128x32x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x32x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 128, 32, 256}}, {21001, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 32, 256}}, {22000, {"MT128x64x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x64x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 128, 64, 256}}, {22001, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 64, 256}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_1", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_1.co", 256, 128, 256}}, {23001, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 128, 256}}, {23002, {"MT128x128x128_SN_K1_SB3_SWMNK4_1_1_moe2", "w8a8/stage2/Cijk_Alik_Bljk_I8HS_BH_SABV_UserArgs_MT128x128x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 128, 128, 256}}, {20100, {"MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 768}}, {20101, {"MT2048x16x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT2048x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 16, 768}}, {20102, {"MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 16, 768}}, {21100, {"MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 32, 768}}, {21101, {"MT2048x32x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT2048x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 32, 768}}, {21102, {"MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 32, 768}}, {22100, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 64, 768}}, {22101, {"MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 64, 768}}, {22102, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1", "w8a8/stage2/128k/MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1.co", 1024, 64, 768}}, {22103, {"MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1", "w8a8/stage2/128k/MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1.co", 2048, 64, 768}}, {23100, {"MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 128, 768}}, {23101, {"MT2048x128x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w8a8/stage2/128k/MT2048x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 128, 768}}}; config_map = &moe2_kernel_int8_configs; } else if (Mode == 2 || Mode == 6) { first_stage =true; first_stage_topk = top_k; static std::unordered_map moe1_kernel_w8a8_block_configs = { {10000, {"MT128x16x128_SN_K1_PGR3_WG16_16_3_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x16x128_SN_K1_SB0_SWMNK8_1_1_moe1_ScaleTolds.co", 128, 16, 768}}, {10001, {"MT128x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10002, {"MT128x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10003, {"MT128x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10004, {"MT256x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10005, {"MT256x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10006, {"MT256x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10007, {"MT64x16x256_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT64x16x256_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1.co", 64, 16, 512}}, {10008, {"MT32x16x256_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT32x16x256_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {11000, {"MT32x32x256_SN_K1_PGR4_WG16_16_2_moe1", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT32x32x256_SN_K1_SB0_SWMNK2_2_1_moe1.co", 32, 32, 512}}, {11001, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT64x32x256_SN_K1_SB0_SWMNK4_2_1_moe1.co", 64, 32, 768}}, {11002, {"MT32x32x256_SN_K1_PGR3_WG16_16_2_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT32x32x256_SN_K1_SB0_SWMNK2_2_1_moe1_ScaleTolds.co", 32, 32, 512}}, {11003, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT64x32x256_SN_K1_SB0_SWMNK4_2_1_moe1_ScaleTolds.co", 64, 32, 768}}, {11004, {"MT64x32x128_SN_K1_PGR4_WG16_16_3_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT64x32x128_SN_K1_SB0_SWMNK4_2_1_moe1_ScaleTolds.co", 64, 32, 768}}, {11005, {"MT128x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11006, {"MT128x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11007, {"MT128x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11008, {"MT256x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11009, {"MT256x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11010, {"MT256x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {12000, {"MT128x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12001, {"MT128x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12002, {"MT256x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12003, {"MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12004, {"MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12005, {"MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12006, {"MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {13000, {"MT128x128x256_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x128x256_SB3_SWMNK8_1_1_moe1.co", 128, 128, 768}}, {13001, {"MT256x128x256_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x128x256_SB3_SWMNK8_1_1_moe1.co", 256, 128, 768}}, {14000, {"MT128x256x256_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT128x256x256_SB3_SWMNK8_1_1_moe1.co", 128, 256, 768}}, {14001, {"MT256x256x256_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w8a8_block/stage1/Cijk_Alik_Bljk_I8HS_MT256x256x256_SB3_SWMNK8_1_1_moe1.co", 256, 256, 768}}, {19999, {"kenrel_name", "w8a8_block/stage1/kenrel_name.co", 0, 0, 0}}}; config_map = &moe1_kernel_w8a8_block_configs; } else if (Mode == 3 || Mode == 7) { first_stage =false; static std::unordered_map moe2_kernel_w8a8_block_configs = { {20000, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 16, 256}}, {21000, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 32, 256}}, {22000, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 64, 256}}, {22001, {"MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2.co", 256, 64, 768}}, {22002, {"MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2.co", 256, 64, 768}}, {22003, {"MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe2.co", 256, 64, 768}}, {22004, {"MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe2.co", 256, 64, 768}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 128, 256}}, {23001, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds_1", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_1.co", 256, 128, 256}}, {24000, {"MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 256, 256}}, {24001, {"MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds_1", "w8a8_block/stage2/Cijk_Alik_Bljk_I8HS_MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2_1.co", 256, 256, 256}}, {20100, {"MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 768}}, {20101, {"MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 16, 768}}, {20200, {"MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/256k/MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 768}}, {20201, {"MT3584x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/256k/MT3584x16x256_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 16, 768}}, {21100, {"MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 32, 768}}, {21101, {"MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 32, 768}}, {21200, {"MT1024x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/256k/MT1024x32x256_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 32, 768}}, {22100, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 64, 768}}, {22101, {"MT3584x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT3584x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 64, 768}}, {23100, {"MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 128, 768}}, {23101, {"MT3584x128x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT3584x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 128, 768}}, {24100, {"MT1024x256x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT1024x256x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 256, 768}}, {24101, {"MT3584x256x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w8a8_block/stage2/128k/MT3584x256x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 256, 768}}, {29999, {"kenrel_name", "w8a8_block/stage2/kenrel_name.co", 0, 0, 0}}}; config_map = &moe2_kernel_w8a8_block_configs; } else if (Mode == 10) { first_stage =true; first_stage_topk = top_k; static std::unordered_map moe1_kernel_w4a8_configs = { {10000, {"MT128x16x256_SN_K1_PGR4_WG16_16_3_moe1", "w4a8/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_SB0_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10001, {"MT128x16x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_SB0_SWMNK8_1_1_moe1_ScaleTolds0.co", 128, 16, 768}}, {10002, {"MT128x16x256_SN_K1_PGR3_WG16_16_3_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT128x16x256_SN_K1_SB0_SWMNK8_1_1_moe1_ScaleTolds1.co", 128, 16, 768}}, {10003, {"MT256x16x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x16x256_SN_K1_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {11000, {"MT128x32x256_SN_K1_PGR4_WG16_16_3_moe1", "w4a8/Cijk_Alik_Bljk_I8HS_MT128x32x256_SN_K1_SB0_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11001, {"MT128x32x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT128x32x256_SN_K1_SB0_SWMNK8_1_1_moe1_ScaleTolds.co", 128, 32, 768}}, {11002, {"MT256x32x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x32x256_SN_K1_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11009, {"MT64x32x256_SN_K1_PGR4_WG16_16_3_moe1", "w4a8/Cijk_Alik_Bljk_I8HS_MT64x32x256_SN_K1_SB0_SWMNK4_2_1_moe1.co", 64, 32, 768}}, {12000, {"MT256x64x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x64x256_SN_K1_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {13000, {"MT256x128x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x128x256_SN_K1_SB3_SWMNK8_1_1_moe1.co", 256, 128, 768}}, {19999, {"kenrel_name", "w4a8/kenrel_name.co", 0, 0, 0}}}; config_map = &moe1_kernel_w4a8_configs; } else if (Mode == 11) { first_stage =false; static std::unordered_map moe2_kernel_w4a8_configs = { {20000, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 16, 256}}, {21000, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 32, 256}}, {22000, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 64, 256}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds", "w4a8/Cijk_Alik_Bljk_I8HS_MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 128, 256}}, {20100, {"MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 768}}, {20101, {"MT3584x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 16, 768}}, {21100, {"MT1024x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 32, 768}}, {21101, {"MT3584x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 32, 768}}, {22100, {"MT1024x64x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 64, 768}}, {22101, {"MT3584x64x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT3584x64x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 64, 768}}, {23100, {"MT1024x128x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 128, 768}}, {23101, {"MT3584x128x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds", "w4a8/256k/MT3584x128x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 3584, 128, 768}}, {29999, {"kenrel_name_moe2", "w4a8/kenrel_name_moe2.co", 0, 0, 0}}}; config_map = &moe2_kernel_w4a8_configs; } else if (Mode == 20) { first_stage =true; first_stage_topk = top_k; static std::unordered_map moe1_kernel_a16_configs = { {10000, {"B16_MT32x16x128_SN_K1_PGR4_WG16_16_2_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x16x128_PGR4_SB3_SWMNK2_2_1_moe1.co", 32, 16, 384}}, {10001, {"B16_MT32x16x128_SN_K1_PGR5_WG16_16_2_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x16x128_PGR5_SB3_SWMNK2_2_1_moe1.co", 32, 16, 384}}, {10002, {"B16_MT64x16x128_SN_K1_PGR3_WG16_16_3_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT64x16x128_PGR3_SB3_SWMNK4_2_1_moe1.co", 64, 16, 512}}, {10003, {"B16_MT32x16x128_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x16x128_T_PGR4_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10004, {"B16_MT32x16x128_SN_K1_PGR5_SB3_SWMNK2_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x16x128_T_PGR5_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10005, {"B16_MT32x16x128_SN_K1_PGR6_SB3_SWMNK2_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x16x128_T_PGR6_SB3_SWMNK2_1_1_moe1.co", 32, 16, 384}}, {10006, {"B16_MT64x16x128_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT64x16x128_T_PGR4_SB3_SWMNK4_1_1_moe1.co", 64, 16, 512}}, {10007, {"B16_MT64x16x128_SN_K1_PGR5_SB3_SWMNK4_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT64x16x128_T_PGR5_SB3_SWMNK4_1_1_moe1.co", 64, 16, 512}}, {10008, {"B16_MT128x16x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x16x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10009, {"B16_MT128x16x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x16x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10010, {"B16_MT128x16x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x16x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 16, 768}}, {10011, {"B16_MT256x16x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x16x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10012, {"B16_MT256x16x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x16x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {10013, {"B16_MT256x16x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x16x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 16, 768}}, {11000, {"B16_MT32x32x128_SN_K1_PGR4_WG16_16_2_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT32x32x128_SN_K1_SB3_SWMNK2_2_1_moe1.co", 32, 32, 512}}, {11002, {"B16_MT128x32x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x32x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11003, {"B16_MT256x32x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x32x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11004, {"B16_MT128x32x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x32x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11005, {"B16_MT256x32x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x32x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {11006, {"B16_MT128x32x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x32x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 32, 768}}, {11007, {"B16_MT256x32x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x32x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 32, 768}}, {12000, {"B16_MT128x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x64x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12001, {"B16_MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x64x128_T_PGR3_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12002, {"B16_MT128x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x64x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12003, {"B16_MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x64x128_T_PGR4_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {12004, {"B16_MT128x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x64x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 64, 768}}, {12005, {"B16_MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x64x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 64, 768}}, {13000, {"B16_MT128x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT128x128x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 128, 128, 768}}, {13001, {"B16_MT256x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1", "w16a16/stage1/Cijk_Alik_Bljk_B16_MT256x128x128_T_PGR2_SB3_SWMNK8_1_1_moe1.co", 256, 128, 768}}, {19999, {"kenrel_name", "w4a8/kenrel_name.co", 0, 0, 0}}}; config_map = &moe1_kernel_a16_configs; } else if (Mode == 21) { first_stage =false; static std::unordered_map moe2_kernel_a16_configs = { {20000, {"B16_MT128x16x64_SN_K1_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT128x16x64_SN_K1_SB3_SWMNK4_1_1_moe2.co", 128, 16, 256}}, {20001, {"B16_MT128x16x64_PGR2_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT128x16x64_PGR2_SB3_SWMNK4_1_1_moe2.co", 128, 16, 256}}, {20002, {"B16_MT256x16x64_SN_K1_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x16x64_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 16, 256}}, {21001, {"B16_MT256x32x64_SN_K1_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x32x64_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 32, 256}}, {22001, {"B16_MT256x64x64_SN_K1_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x64x64_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 64, 256}}, {23000, {"B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2_1", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2_1.co", 256, 128, 256}}, {23001, {"B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2.co", 256, 128, 256}}, {23002, {"B16_MT256x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/Cijk_Alik_Bljk_B16_MT256x128x128_T_PGR2_SB3_SWMNK8_1_1_moe2.co", 256, 128, 768}}, {20100, {"B16_MT1024x16x64_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x64_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 768}}, {20101, {"B16_MT2048x16x64_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT2048x16x64_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 16, 768}}, {20200, {"B16_MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 512}}, {20201, {"B16_MT2048x16x128_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT2048x16x128_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 16, 512}}, {20300, {"B16_MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 512}}, {20301, {"B16_MT2048x16x256_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT2048x16x256_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 16, 512}}, {20400, {"B16_MT1024x16x384_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x384_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 512}}, {20401, {"B16_MT2048x16x384_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT2048x16x384_SN_K1_SB3_SWMNK8_1_1_moe2.co", 2048, 16, 512}}, {20500, {"B16_MT1024x16x352_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x352_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 512}}, {20600, {"B16_MT1024x16x320_SN_K1_SB3_SWMNK8_1_1_moe2", "w16a16/stage2/eval/MT1024x16x320_SN_K1_SB3_SWMNK8_1_1_moe2.co", 1024, 16, 512}}, {29999, {"kenrel_name_moe2", "w4a8/kenrel_name_moe2.co", 0, 0, 0}}}; config_map = &moe2_kernel_a16_configs; } //if (Mode >= 10000 && Mode <= 19999) { // first_stage =true; // first_stage_topk = top_k; // config_map = &moe1_kernel_int8_configs; // Moe1ModeLookup lookup; // Mode = lookup.getMode(input.size(0)); //} //else if (Mode >= 20000 && Mode <= 29999) { // first_stage =false; // config_map = &moe2_kernel_int8_configs; // Moe2ModeLookup lookup; // Mode = lookup.getMode(input.size(0) / first_stage_topk, input.size(1)); //} if (!config_map) { TORCH_CHECK(false, __func__, " Input not supput Mode: ", Mode); } int Solution = 0; if (first_stage) Solution = 10000; else Solution = 20000; if (solidx.has_value()) { Solution = solidx.value(); } //auto it = config_map->find(Mode); auto it = config_map->find(Solution); if (it != config_map->end()) { const auto &config = it->second; const char *name = config.name.c_str(); std::string config_co_name = config.co_name; if (Mode >= 4 && Mode <= 7) { size_t pos = config_co_name.find("w8a8"); if (pos != std::string::npos) { config_co_name.insert(pos, "f8/"); } } if ((Mode == 10 || Mode == 11) && OutType == 1) { size_t pos = config_co_name.find("w4a8/"); if (pos != std::string::npos) { config_co_name.insert(pos + 5, "bf16/"); } } if ((Mode == 20 || Mode == 21) && OutType == 1) { size_t pos = config_co_name.find("w16a16/"); if (pos != std::string::npos) { config_co_name.insert(pos + 7, "bf16/"); } } if (Shuffle != 0) { size_t pos = config_co_name.find("stage"); if (pos != std::string::npos) { config_co_name.insert(pos + 7, "shuffle/"); } if ((Mode == 0 || Mode == 4) && gate.size(1) % 128 == 64) { config_co_name.insert(pos + 14, "_N64"); } if ((Mode == 1 || Mode == 5) && down.size(2) % 128 == 96) { config_co_name.insert(pos + 14, "_K96"); } else if ((Mode == 1 || Mode == 5) && down.size(2) % 128 == 64) { config_co_name.insert(pos + 14, "_K64"); } } const char *co_name = config_co_name.c_str(); Config = {config.MT0, config.MT1, config.bdx}; auto result = impl_ptr_map.emplace(co_name, nullptr); if (result.second) { result.first->second = std::make_unique(name, co_name); } impl_ptr = result.first->second.get(); } TORCH_CHECK(impl_ptr != nullptr, __func__, ": unsupport current input Mode: ", Mode, ",Solution:", Solution); if (first_stage) { BlockSizeM = Config[1]; impl_ptr->launch_kernel(Config, out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, top_k, OutType, PersistGroups, BlockSizeM, ScaleA, ScaleB, ZeroPoints); } else { if (BlockSizeM == 0) BlockSizeM = Config[1]; if (BlockSizeM < Config[1]) TORCH_CHECK(false, __func__," Currently only supports stage2 kernel MT1 <= stage1 kernel MT1, ", Config[1], " > ", BlockSizeM); impl_ptr->launch_kernel(Config, out, input, gate, down, sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, top_k, OutType, PersistGroups, BlockSizeM, ScaleA, ScaleB, ZeroPoints); } }