// SPDX-License-Identifier: MIT #include "asm_gemm_kernel_config.h" #include #include #include #include #include #include #include "nlohmann/json.hpp" using json = nlohmann::json; static std::once_flag g_once; static json g_json; static std::once_flag g_once_bf16; static json g_json_bf16; static void ensure_json_loaded(std::string& filename) { char delimiter = '/'; size_t pos = filename.rfind(delimiter); std::string tfilename = filename; if (pos != std::string::npos) { tfilename = filename.substr(pos + 1); } if (tfilename.find("bf16") != std::string::npos) { std::call_once(g_once_bf16, [filename](){ std::ifstream ifs(filename); if (!ifs) throw std::runtime_error("Cannot open json: " + filename); try { ifs >> g_json_bf16; } catch (const std::exception& e) { throw std::runtime_error(std::string("JSON parse error in ") + filename + ": " + e.what()); } }); } else { std::call_once(g_once, [filename](){ std::ifstream ifs(filename); if (!ifs) throw std::runtime_error("Cannot open json: " + filename); try { ifs >> g_json; } catch (const std::exception& e) { throw std::runtime_error(std::string("JSON parse error in ") + filename + ": " + e.what()); } }); } } KernelCfg get_kernel_cfg_by_index(int index, std::string& filename) { ensure_json_loaded(filename); KernelCfg cfg; char delimiter = '/'; size_t pos = filename.rfind(delimiter); std::string tfilename = filename; if (pos != std::string::npos) { tfilename = filename.substr(pos + 1); } if (tfilename.find("bf16") != std::string::npos) { cfg.kernel_name = g_json_bf16["kernels"][index]["kernel_name"]; cfg.co_file = g_json_bf16["kernels"][index]["co_file"]; cfg.mt0 = g_json_bf16["kernels"][index]["Kconfigs"]["mt0"]; cfg.mt1 = g_json_bf16["kernels"][index]["Kconfigs"]["mt1"]; cfg.numThreads = g_json_bf16["kernels"][index]["Kconfigs"]["numThreads"]; cfg.wgm = g_json_bf16["kernels"][index]["Kconfigs"]["wgm"]; } else { cfg.kernel_name = g_json["kernels"][index]["kernel_name"]; cfg.co_file = g_json["kernels"][index]["co_file"]; cfg.mt0 = g_json["kernels"][index]["Kconfigs"]["mt0"]; cfg.mt1 = g_json["kernels"][index]["Kconfigs"]["mt1"]; cfg.numThreads = g_json["kernels"][index]["Kconfigs"]["numThreads"]; cfg.wgm = g_json["kernels"][index]["Kconfigs"]["wgm"]; } return cfg; } static int solcol; KernelCfg get_kernel_cfg_by_csv(MatchProblem& prob, std::string& filename) { const char* aiter_root = std::getenv("AITER_META_DIR"); if (aiter_root == nullptr) std::cerr << "[ERROR]ENV AITER_META_DIR not set" << std::endl; std::string root_dir(aiter_root); std::string jsonfile = root_dir + "/aiter/configs/asm_tune/" + filename; ensure_json_loaded(jsonfile); std::string tunedCSV = " "; if (filename.find("bf16") != std::string::npos) { tunedCSV = g_json_bf16["tunedCSV"]; } else { tunedCSV = g_json["tunedCSV"]; } std::string csvfile = root_dir + "/aiter/configs/asm_tune/" + tunedCSV; auto data = read_csv(csvfile, &solcol); int solutionid = 0; for (const auto& row : data) { if(std::stoi(row[0]) == prob.M && std::stoi(row[1]) == prob.N && std::stoi(row[2]) == prob.K) { solutionid = std::stoi(row[solcol]); break; } } KernelCfg cfg; if (filename.find("bf16") != std::string::npos) { cfg.kernel_name = g_json_bf16["kernels"][solutionid]["kernel_name"]; cfg.co_file = g_json_bf16["kernels"][solutionid]["co_file"]; cfg.mt0 = g_json_bf16["kernels"][solutionid]["Kconfigs"]["mt0"]; cfg.mt1 = g_json_bf16["kernels"][solutionid]["Kconfigs"]["mt1"]; cfg.numThreads = g_json_bf16["kernels"][solutionid]["Kconfigs"]["numThreads"]; cfg.wgm = g_json_bf16["kernels"][solutionid]["Kconfigs"]["wgm"]; } else { cfg.kernel_name = g_json["kernels"][solutionid]["kernel_name"]; cfg.co_file = g_json["kernels"][solutionid]["co_file"]; cfg.mt0 = g_json["kernels"][solutionid]["Kconfigs"]["mt0"]; cfg.mt1 = g_json["kernels"][solutionid]["Kconfigs"]["mt1"]; cfg.numThreads = g_json["kernels"][solutionid]["Kconfigs"]["numThreads"]; cfg.wgm = g_json["kernels"][solutionid]["Kconfigs"]["wgm"]; } return cfg; } std::vector> read_csv(const std::string& filename, int* solidx) { static std::once_flag flag; static std::vector> result; static std::once_flag flag_bf16; static std::vector> result_bf16; char delimiter = '/'; size_t pos = filename.rfind(delimiter); std::string tfilename = filename; if (pos != std::string::npos) { tfilename = filename.substr(pos + 1); } if (tfilename.find("bf16") != std::string::npos) { std::call_once(flag_bf16, [filename, solidx]() { // 原有的CSV读取逻辑 std::ifstream file(filename); std::string line; std::string headers; if (!file.is_open()) { throw std::runtime_error("Can not open:" + filename); } //找出solidx 是第几列 std::getline(file, headers); std::stringstream head(headers); std::string tmpcell; int index = 0; while (std::getline(head, tmpcell, ',')) { if (tmpcell.compare("solidx") == 0) { *solidx = static_cast(index); break; } index++; } //读出数据 while (std::getline(file, line)) { std::vector row; std::stringstream ss(line); std::string cell; while (std::getline(ss, cell, ',')) { row.push_back(cell); } result_bf16.push_back(row); } file.close(); std::cout << "Reading bf16 CSV file completed!!!" << std::endl; }); return result_bf16; } else { std::call_once(flag, [filename, solidx]() { // 原有的CSV读取逻辑 std::ifstream file(filename); std::string line; std::string headers; if (!file.is_open()) { throw std::runtime_error("Can not open: " + filename); } //找出solidx 是第几列 std::getline(file, headers); std::stringstream head(headers); std::string tmpcell; int index = 0; while (std::getline(head, tmpcell, ',')) { if (tmpcell.compare("solidx") == 0) { *solidx = static_cast(index); break; } index++; } //读出数据 while (std::getline(file, line)) { std::vector row; std::stringstream ss(line); std::string cell; while (std::getline(ss, cell, ',')) { row.push_back(cell); } result.push_back(row); } file.close(); std::cout << "Reading fp16 CSV file completed!!!" << std::endl; }); return result; } }