// SPDX-License-Identifier: MIT #include #include #include #include #include "aiter_hip_common.h" #include "moe_op.h" //#include "py_itfs_common.h" #include #include #include #define USE_SHUFFLE 1 std::vector get_w4a16_solutions() { static std::unordered_map moe1_kernel_w4a16_configs = { {10000, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT16x128x128_SN_K1_PGR4_TT1_8_WG16_16_3"}}, {11001, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x128x128_SN_K1_PGR4_TT2_8_WG16_16_3"}}, {11002, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x128x128_SN_K1_PGR4_TT2_8_WG16_16_3"}}}; //{11000, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x128x128_SN_K1_PGR4_TT2_8_WG16_16_3"}} kernel failed static std::unordered_map moe2_kernel_w4a16_configs = { {20000, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT16x3584x16_SN_K1_TT1_224_WG16_16_3_WGM1"}}, {20001, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT16x7168x16_SN_K1_TT1_448_WG16_16_3_WGM1"}}, {20002, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT16x1024x16_SN_K1_TT1_64_WG16_16_3_WGM1"}}, {21001, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x1024x16_SN_K1_TT2_64_WG16_16_3_WGM1"}}, {21002, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x1024x16_SN_K1_TT2_64_WG16_16_3_WGM1"}}}; //{21000, {"Cijk_Alik_Bljk_HHS_BH_UserArgs_MT32x7168x16_SN_K1_TT2_448_WG16_16_3_WGM1"}} kernel failed std::vector validSolutions; std::vector> rangeRules = { {10000, 20000}, {11000, 21000} }; for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int configSpan = 1000; for (const auto& pair1 : moe1_kernel_w4a16_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + configSpan) { for (const auto& pair2 : moe2_kernel_w4a16_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + configSpan) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } return validSolutions; } std::vector get_w4a8_solutions(int hdim_size=0) { static std::unordered_map moe1_kernel_w4a8_configs = { {10000, {"MT128x16x256_SN_K1_PGR4_WG16_16_3_moe1"}}, {10001, {"MT128x16x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds"}}, {10002, {"MT128x16x256_SN_K1_PGR3_WG16_16_3_moe1_ScaleTolds"}}, {10003, {"MT256x16x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11000, {"MT128x32x256_SN_K1_PGR4_WG16_16_3_moe1"}}, {11001, {"MT128x32x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds"}}, {11002, {"MT256x32x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11009, {"MT64x32x256_SN_K1_PGR4_WG16_16_3_moe1"}}, {12000, {"MT256x64x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {13000, {"MT256x128x256_SN_K1_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}}; static std::unordered_map moe2_kernel_w4a8_configs = { {20000, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {21000, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {22000, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {20100, {"MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {20101, {"MT3584x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {21100, {"MT1024x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {21101, {"MT3584x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22100, {"MT1024x64x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22101, {"MT3584x64x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {23100, {"MT1024x128x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {23101, {"MT3584x128x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}}; std::vector validSolutions; std::vector> rangeRules = { {10000, 20000}, {11000, 20000}, {11000, 21000}, {12000, 20000}, {12000, 21000}, {12000, 22000}, {13000, 21000}, {13000, 22000}, {13000, 23000}, }; if (hdim_size == 256) { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int configSpan = 1000; for (const auto& pair1 : moe1_kernel_w4a8_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + configSpan) { for (const auto& pair2 : moe2_kernel_w4a8_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + configSpan) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } else { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int config1Span = 1000; int config2Span = 100; for (const auto& pair1 : moe1_kernel_w4a8_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + config1Span) { for (const auto& pair2 : moe2_kernel_w4a8_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + config2Span) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } return validSolutions; } std::vector get_w8a8_solutions(int hdim_size=0) { static std::unordered_map moe1_kernel_int8_configs = { {10000, {"MT32x16x256_SN_K1_PGR4_WG16_16_2_moe1"}}, {10001, {"MT32x16x256_SN_K1_PGR5_WG16_16_2_moe1"}}, {10002, {"MT64x16x256_SN_K1_PGR3_WG16_16_3_moe1"}}, {10003, {"MT32x16x256_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1"}}, {10004, {"MT32x16x256_SN_K1_PGR5_SB3_SWMNK2_1_1_moe1"}}, {10005, {"MT32x16x256_SN_K1_PGR6_SB3_SWMNK2_1_1_moe1"}}, {10006, {"MT64x16x256_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1"}}, {10007, {"MT64x16x256_SN_K1_PGR5_SB3_SWMNK4_1_1_moe1"}}, {10008, {"MT128x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {10009, {"MT128x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {10010, {"MT128x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {10011, {"MT256x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {10012, {"MT256x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {10013, {"MT256x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {11000, {"MT32x32x256_SN_K1_PGR4_WG16_16_2_moe1"}}, {11001, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1"}}, {11002, {"MT128x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {11003, {"MT256x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {11004, {"MT128x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {11005, {"MT256x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {11006, {"MT128x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {11007, {"MT256x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {12000, {"MT128x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {12001, {"MT256x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {12002, {"MT128x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {12003, {"MT256x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {12004, {"MT128x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {12005, {"MT256x64x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {13000, {"MT128x128x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {13001, {"MT256x128x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}}; static std::unordered_map moe2_kernel_int8_configs = { {20000, {"MT128x16x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {20001, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {21000, {"MT128x32x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {21001, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {22000, {"MT128x64x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {22001, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_1"}}, {23001, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {23002, {"MT128x128x128_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {20100, {"MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {20101, {"MT2048x16x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {20102, {"MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {21100, {"MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {21101, {"MT2048x32x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {21102, {"MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {22100, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {22101, {"MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {22102, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1"}}, {22103, {"MT2048x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_1"}}, {23100, {"MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}, {23101, {"MT2048x128x128_SN_K1_SB3_SWMNK8_1_1_moe2"}}}; std::vector validSolutions; std::vector> rangeRules = { {10000, 20000}, {11000, 20000}, {11000, 21000}, {12000, 20000}, {12000, 21000}, {12000, 22000}, {13000, 21000}, {13000, 22000}, {13000, 23000}, }; if (hdim_size == 128) { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int configSpan = 1000; for (const auto& pair1 : moe1_kernel_int8_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + configSpan) { for (const auto& pair2 : moe2_kernel_int8_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + configSpan) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } else { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int config1Span = 1000; int config2Span = 100; for (const auto& pair1 : moe1_kernel_int8_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + config1Span) { for (const auto& pair2 : moe2_kernel_int8_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + config2Span) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } return validSolutions; } std::vector get_w8a8_g_solutions(int hdim_size=0) { static std::unordered_map moe1_kernel_w8a8_block_configs = { {10000, {"MT128x16x128_SN_K1_PGR3_WG16_16_3_moe1_ScaleTolds"}}, {10001, {"MT128x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10002, {"MT128x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10003, {"MT128x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10004, {"MT256x16x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10005, {"MT256x16x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10006, {"MT256x16x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {10007, {"MT64x16x256_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1_ScaleTolds"}}, {10008, {"MT32x16x256_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1_ScaleTolds"}}, {11000, {"MT32x32x256_SN_K1_PGR4_WG16_16_2_moe1"}}, {11001, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1"}}, {11002, {"MT32x32x256_SN_K1_PGR3_WG16_16_2_moe1_ScaleTolds"}}, {11003, {"MT64x32x256_SN_K1_PGR2_WG16_16_3_moe1_ScaleTolds"}}, {11004, {"MT64x32x128_SN_K1_PGR4_WG16_16_3_moe1_ScaleTolds"}}, {11005, {"MT128x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11006, {"MT128x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11007, {"MT128x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11008, {"MT256x32x256_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11009, {"MT256x32x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {11010, {"MT256x32x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12000, {"MT128x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12001, {"MT128x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12002, {"MT256x64x256_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12003, {"MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12004, {"MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12005, {"MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {12006, {"MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {13000, {"MT128x128x256_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {13001, {"MT256x128x256_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {14000, {"MT128x256x256_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}, {14001, {"MT256x256x256_SB3_SWMNK8_1_1_moe1_ScaleTolds"}}}; static std::unordered_map moe2_kernel_w8a8_block_configs = { {20000, {"MT256x16x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {21000, {"MT256x32x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {22000, {"MT256x64x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {22001, {"MT256x64x256_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22002, {"MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22003, {"MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22004, {"MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {23000, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {23001, {"MT256x128x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds_1"}}, {24000, {"MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds"}}, {24001, {"MT256x256x128_SN_K1_SB3_SWMNK4_1_1_moe2_ScaleTolds_1"}}, {20100, {"MT1024x16x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {20101, {"MT3584x16x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {21100, {"MT1024x32x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {21101, {"MT3584x32x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22100, {"MT1024x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {22101, {"MT3584x64x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {23100, {"MT1024x128x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {23101, {"MT3584x128x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {24100, {"MT1024x256x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {24101, {"MT3584x256x128_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {20200, {"MT1024x16x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}, {21200, {"MT1024x32x256_SN_K1_SB3_SWMNK8_1_1_moe2_ScaleTolds"}}}; std::vector validSolutions; std::vector> rangeRules = { {10000, 20000}, {11000, 20000}, {11000, 21000}, {12000, 20000}, {12000, 21000}, {12000, 22000}, {13000, 21000}, {13000, 22000}, {13000, 23000}, {14000, 22000}, {14000, 23000}, {14000, 24000}, }; //Shuffle does not have the following solution id #ifdef USE_SHUFFLE moe1_kernel_w8a8_block_configs.erase(11000); moe1_kernel_w8a8_block_configs.erase(11001); #endif if (hdim_size == 128) { moe2_kernel_w8a8_block_configs.erase(22001); for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int config1Span = 1000; int config2Span = 200; for (const auto& pair1 : moe1_kernel_w8a8_block_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + config1Span) { for (const auto& pair2 : moe2_kernel_w8a8_block_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + config2Span) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } if (hdim_size == 256) { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config1Span = 1000; int config2Start1 = rule.second; int config2End1 = config2Start1 + 100; int config2Start2 = rule.second + 200; int config2End2 = config2Start2 + 100; for (const auto& pair1 : moe1_kernel_w8a8_block_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + config1Span) { for (const auto& pair2 : moe2_kernel_w8a8_block_configs) { int key2 = pair2.first; if ((key2 >= config2Start1 && key2 < config2End1) || (key2 >= config2Start2 && key2 < config2End2)) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } else { for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int config1Span = 1000; int config2Span = 100; for (const auto& pair1 : moe1_kernel_w8a8_block_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + config1Span) { for (const auto& pair2 : moe2_kernel_w8a8_block_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + config2Span) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } } return validSolutions; } std::vector get_w16a16_solutions() { static std::unordered_map moe1_kernel_w16a16_configs = { {10000, {"B16_MT32x16x128_SN_K1_PGR4_WG16_16_2_moe1"}}, {10001, {"B16_MT32x16x128_SN_K1_PGR5_WG16_16_2_moe1"}}, {10002, {"B16_MT64x16x128_SN_K1_PGR3_WG16_16_3_moe1"}}, {10003, {"B16_MT32x16x128_SN_K1_PGR4_SB3_SWMNK2_1_1_moe1"}}, {10004, {"B16_MT32x16x128_SN_K1_PGR5_SB3_SWMNK2_1_1_moe1"}}, {10005, {"B16_MT32x16x128_SN_K1_PGR6_SB3_SWMNK2_1_1_moe1"}}, {10006, {"B16_MT64x16x128_SN_K1_PGR4_SB3_SWMNK4_1_1_moe1"}}, {10007, {"B16_MT64x16x128_SN_K1_PGR5_SB3_SWMNK4_1_1_moe1"}}, {10008, {"B16_MT128x16x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {10009, {"B16_MT128x16x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {10010, {"B16_MT128x16x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {10011, {"B16_MT256x16x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {10012, {"B16_MT256x16x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {10013, {"B16_MT256x16x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {11000, {"B16_MT32x32x128_SN_K1_PGR4_WG16_16_2_moe1"}}, {11002, {"B16_MT128x32x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {11003, {"B16_MT256x32x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {11004, {"B16_MT128x32x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {11005, {"B16_MT256x32x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {11006, {"B16_MT128x32x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {11007, {"B16_MT256x32x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {12000, {"B16_MT128x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {12001, {"B16_MT256x64x128_SN_K1_PGR3_SB3_SWMNK8_1_1_moe1"}}, {12002, {"B16_MT128x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {12003, {"B16_MT256x64x128_SN_K1_PGR4_SB3_SWMNK8_1_1_moe1"}}, {12004, {"B16_MT128x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {12005, {"B16_MT256x64x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {13000, {"B16_MT128x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}, {13001, {"B16_MT256x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe1"}}}; static std::unordered_map moe2_kernel_w16a16_configs = { {20000, {"B16_MT128x16x64_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {20001, {"B16_MT128x16x64_PGR2_SB3_SWMNK4_1_1_moe2"}}, {20002, {"B16_MT256x16x64_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {21001, {"B16_MT256x32x64_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {22001, {"B16_MT256x64x64_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {23000, {"B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2_1"}}, {23001, {"B16_MT256x128x64_SN_K1_SB3_SWMNK4_1_1_moe2"}}, {23002, {"B16_MT256x128x128_SN_K1_PGR2_SB3_SWMNK8_1_1_moe2"}}}; std::vector validSolutions; std::vector> rangeRules = { {10000, 20000}, {11000, 20000}, {11000, 21000}, {12000, 20000}, {12000, 21000}, {12000, 22000}, {13000, 21000}, {13000, 22000}, {13000, 23000}, }; for (const auto& rule : rangeRules) { int config1Start = rule.first; int config2Start = rule.second; int configSpan = 1000; for (const auto& pair1 : moe1_kernel_w16a16_configs) { int key1 = pair1.first; if (key1 >= config1Start && key1 < config1Start + configSpan) { for (const auto& pair2 : moe2_kernel_w16a16_configs) { int key2 = pair2.first; if (key2 >= config2Start && key2 < config2Start + configSpan) { std::string combined = std::to_string(key1) + "+" + std::to_string(key2); validSolutions.push_back(combined); } } } } } return validSolutions; } std::vector asm_moe_get_solutions(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &w2, // [e, n, k], pre-shuffle([e, nr, kr, w]) torch::Tensor &topk_weights, // [tokens, topk] torch::Tensor &topk_ids, // [tokens, topk] std::optional use_int8_w8a16, // use int8 w8a16 quantization std::optional use_int4_w4a16, // use int4 w4a16 quantization std::optional use_int8_w8a8, // use int8 w8a8 quantization std::optional use_int4_w4a8, // use int4 w4a8 quantization std::optional use_fp8_w8a8, // use f8 w8a8 quantization std::optional per_channel_quant, // use channel quantization std::optional w1_zp, // [e, 2*n, k/group], gate(up) zero-point std::optional w2_zp, // [e, k, n/group], down zero-point std::optional w1_scale, // [e, 1, n], gate(up) scale or ... std::optional w2_scale, // [e, 1, k], down scale or ... std::optional a1_scale, // [m, 1], token scale std::optional a2_scale, // [e, 1, n], smooth-quant-scale for 2nd gemm input std::optional block_shape_n, // quant block n size std::optional block_shape_k, // quant block k size std::optional block_m = 32, // moe partion size for tokens in m direction std::optional expert_mask = std::nullopt) { int experts = w1.size(0); int topk = topk_ids.size(1); int tokens = topk_ids.size(0); int hidden_size = w1.size(2); int hdim_size = w2.size(2); int block_size = block_m.has_value() ? block_m.value() : 0; if (use_int4_w4a16.has_value() && use_int4_w4a16.value()) { return get_w4a16_solutions(); } else if ((use_int8_w8a8.has_value() && use_int8_w8a8.value() && per_channel_quant.has_value() && not per_channel_quant.value()) || (use_fp8_w8a8.has_value() && use_fp8_w8a8.value() && per_channel_quant.has_value() && not per_channel_quant.value()) ) { return get_w8a8_g_solutions(hdim_size); } else if ((use_int8_w8a8.has_value() && use_int8_w8a8.value() && per_channel_quant.has_value() && per_channel_quant.value()) || (use_fp8_w8a8.has_value() && use_fp8_w8a8.value() && per_channel_quant.has_value() && per_channel_quant.value()) ) { return get_w8a8_solutions(hdim_size); } else if (use_int4_w4a8.has_value() && use_int4_w4a8.value()) { return get_w4a8_solutions(hdim_size*2); } else { return get_w16a16_solutions(); } }