asm_gemm_kernel_config.cpp 7.92 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
// SPDX-License-Identifier: MIT
#include "asm_gemm_kernel_config.h"

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "nlohmann/json.hpp"
using json = nlohmann::json;

// Process-wide JSON kernel-config caches, one per dtype.
// ensure_json_loaded() parses a config file into g_json_bf16 (at most once, guarded by
// g_once_bf16) when the file's basename contains "bf16", and into g_json (guarded by
// g_once) otherwise. The get_kernel_cfg_* functions read kernel entries from these.
static std::once_flag g_once;
static json g_json;
static std::once_flag g_once_bf16;
static json g_json_bf16;


static void ensure_json_loaded(std::string& filename) {
    char delimiter = '/';
    size_t pos = filename.rfind(delimiter);
    std::string tfilename = filename;
    if (pos != std::string::npos) {
        tfilename = filename.substr(pos + 1);
    }

    if (tfilename.find("bf16") != std::string::npos) {
        std::call_once(g_once_bf16, [filename](){
        std::ifstream ifs(filename);
        if (!ifs) throw std::runtime_error("Cannot open json: " + filename);
        try {
            ifs >> g_json_bf16;
        } catch (const std::exception& e) {
            throw std::runtime_error(std::string("JSON parse error in ")
                                     + filename + ": " + e.what());
        }
        });
    } else {
        std::call_once(g_once, [filename](){
        std::ifstream ifs(filename);
        if (!ifs) throw std::runtime_error("Cannot open json: " + filename);
        try {
            ifs >> g_json;
        } catch (const std::exception& e) {
            throw std::runtime_error(std::string("JSON parse error in ")
                                     + filename + ": " + e.what());
        }
        });
    }
    
}


// Returns the kernel configuration stored at position `index` of the "kernels" array in
// the JSON config `filename`. The file is loaded (once per dtype) via ensure_json_loaded().
//
// The cache is chosen by the file's basename: "bf16" in the basename selects the bf16
// cache, and anything else selects the default cache.
//
// @throws std::out_of_range if `index` is negative or past the end of "kernels".
// @throws nlohmann::json::exception if "kernels", a kernel field, or "Kconfigs" is missing
//         or has the wrong type.
KernelCfg get_kernel_cfg_by_index(int index, std::string& filename)
{
    ensure_json_loaded(filename);

    const size_t slash = filename.rfind('/');
    const std::string basename =
        (slash == std::string::npos) ? filename : filename.substr(slash + 1);

    // Bind const references so that every lookup below goes through at().
    // operator[] on a non-const json inserts null for a missing key and grows an array
    // for an out-of-range index, which silently corrupts the shared cached document.
    const auto& root = (basename.find("bf16") != std::string::npos) ? g_json_bf16 : g_json;
    const auto& kernels = root.at("kernels");

    if (index < 0 || static_cast<size_t>(index) >= kernels.size()) {
        throw std::out_of_range("kernel index " + std::to_string(index)
                                + " out of range for " + filename + " ("
                                + std::to_string(kernels.size()) + " kernels)");
    }

    const auto& kernel = kernels.at(static_cast<size_t>(index));
    const auto& kconfigs = kernel.at("Kconfigs");

    KernelCfg cfg;
    cfg.kernel_name = kernel.at("kernel_name");
    cfg.co_file     = kernel.at("co_file");
    cfg.mt0         = kconfigs.at("mt0");
    cfg.mt1         = kconfigs.at("mt1");
    cfg.numThreads  = kconfigs.at("numThreads");
    cfg.wgm         = kconfigs.at("wgm");
    return cfg;
}

// Column index of "solidx" in each tuned CSV, as reported by read_csv().
// There is one slot per dtype: with a single shared slot, loading the fp16 table would
// overwrite the column index used for the bf16 table, and vice versa.
// -1 means "not reported (yet)".
static int solcol = -1;
static int solcol_bf16 = -1;

// Looks up the tuned kernel for problem size (prob.M, prob.N, prob.K).
//
// Steps:
//  1. Loads $AITER_META_DIR/aiter/configs/asm_tune/<filename>, a JSON file whose
//     "tunedCSV" entry names a tuning table in the same directory.
//  2. Scans that CSV for the first row whose first three columns equal M, N and K.
//  3. Returns the kernel whose index is in that row's "solidx" column.
// Kernel 0 is returned when no row matches.
//
// @throws std::runtime_error if AITER_META_DIR is unset or the CSV has no "solidx" column.
// @throws std::invalid_argument / std::out_of_range from std::stoi on non-numeric cells.
// @throws std::out_of_range / nlohmann::json::exception if the selected kernel entry is
//         missing or malformed (propagated from get_kernel_cfg_by_index()).
KernelCfg get_kernel_cfg_by_csv(MatchProblem& prob, std::string& filename)
{
    const char* aiter_root = std::getenv("AITER_META_DIR");
    if (aiter_root == nullptr) {
        std::cerr << "[ERROR]ENV AITER_META_DIR not set" << std::endl;
        // Constructing std::string from a null pointer is undefined behavior; fail loudly.
        throw std::runtime_error("ENV AITER_META_DIR not set");
    }
    const std::string config_dir = std::string(aiter_root) + "/aiter/configs/asm_tune/";
    std::string jsonfile = config_dir + filename;
    ensure_json_loaded(jsonfile);

    // Select the dtype by basename, exactly as ensure_json_loaded() does for jsonfile.
    const size_t slash = filename.rfind('/');
    const std::string basename =
        (slash == std::string::npos) ? filename : filename.substr(slash + 1);
    const bool is_bf16 = basename.find("bf16") != std::string::npos;

    // Const access goes through at() so a missing key throws instead of inserting null.
    const auto& root = is_bf16 ? g_json_bf16 : g_json;
    const std::string tunedCSV = root.at("tunedCSV").get<std::string>();
    const std::string csvfile = config_dir + tunedCSV;

    int& sol_col = is_bf16 ? solcol_bf16 : solcol;
    const auto data = read_csv(csvfile, &sol_col);
    if (sol_col < 0) {
        throw std::runtime_error("No 'solidx' column in tuned CSV: " + csvfile);
    }

    // A usable row needs M, N, K and the solution-id column.
    const size_t needed_for_solidx = static_cast<size_t>(sol_col) + 1;
    const size_t min_cols = needed_for_solidx > 3 ? needed_for_solidx : 3;

    int solutionid = 0;  // fallback when this problem size was never tuned
    for (const auto& row : data) {
        if (row.size() < min_cols) {
            continue;  // blank or truncated line: indexing it would read past the end
        }
        if (std::stoi(row[0]) == prob.M && std::stoi(row[1]) == prob.N && std::stoi(row[2]) == prob.K) {
            solutionid = std::stoi(row[static_cast<size_t>(sol_col)]);
            break;
        }
    }

    // jsonfile has the same basename as filename, so this reads the same dtype cache.
    return get_kernel_cfg_by_index(solutionid, jsonfile);
}



namespace {

// Parsed contents of one tuned-GEMM CSV file.
struct TunedCsvTable {
    std::vector<std::vector<std::string>> rows;  // data rows (header excluded), split on ','
    int solidx_col = -1;                         // index of the "solidx" header column, -1 if absent
};

// Removes a trailing '\r' so CSV files saved with CRLF line endings parse like LF files.
// Without this, a last header cell would read "solidx\r" and never match.
void strip_trailing_cr(std::string& line)
{
    if (!line.empty() && line.back() == '\r') {
        line.pop_back();
    }
}

// Splits one line on ','. Quoting and escaping are not supported.
std::vector<std::string> split_csv_line(const std::string& line)
{
    std::vector<std::string> cells;
    std::istringstream stream(line);
    std::string cell;
    while (std::getline(stream, cell, ',')) {
        cells.push_back(cell);
    }
    return cells;
}

// Reads `filename`. First locates the "solidx" column in the header line, then collects
// every non-blank data line.
// @throws std::runtime_error if the file cannot be opened.
TunedCsvTable load_tuned_csv(const std::string& filename)
{
    std::ifstream file(filename);
    if (!file.is_open()) {
        throw std::runtime_error("Can not open: " + filename);
    }

    TunedCsvTable table;
    std::string line;

    // Header: find which column holds the solution index.
    if (std::getline(file, line)) {
        strip_trailing_cr(line);
        const std::vector<std::string> headers = split_csv_line(line);
        for (size_t i = 0; i < headers.size(); ++i) {
            if (headers[i] == "solidx") {
                table.solidx_col = static_cast<int>(i);
                break;
            }
        }
    }

    // Data rows.
    while (std::getline(file, line)) {
        strip_trailing_cr(line);
        if (line.empty()) {
            continue;  // blank lines would otherwise become empty rows that callers index into
        }
        table.rows.push_back(split_csv_line(line));
    }
    return table;
}

}  // namespace

// Returns the data rows of the tuned CSV `filename`, each split on ','.
//
// Each file is parsed at most once per dtype. A basename containing "bf16" uses the bf16
// cache, and anything else uses the fp16 cache.
// NOTE: the first file loaded for a dtype is cached for the whole process. Later calls for
// that dtype return the cached table regardless of `filename`.
//
// @param filename  Path to the CSV. Its first line must be a header.
// @param solidx    If non-null, receives the index of the "solidx" header column on every
//                  call. It is left unchanged when the header has no such column.
// @return          A copy of the cached data rows (header excluded, blank lines skipped).
// @throws std::runtime_error if the file cannot be opened (the load is retried next call).
std::vector<std::vector<std::string>> read_csv(const std::string& filename, int* solidx)
{
    static std::once_flag flag;
    static TunedCsvTable table;
    static std::once_flag flag_bf16;
    static TunedCsvTable table_bf16;

    const size_t slash = filename.rfind('/');
    const std::string basename =
        (slash == std::string::npos) ? filename : filename.substr(slash + 1);
    const bool is_bf16 = basename.find("bf16") != std::string::npos;

    std::once_flag& once = is_bf16 ? flag_bf16 : flag;
    TunedCsvTable& cache = is_bf16 ? table_bf16 : table;

    std::call_once(once, [&]() {
        cache = load_tuned_csv(filename);
        if (cache.solidx_col < 0) {
            std::cerr << "[WARN] no 'solidx' column in " << filename << '\n';
        }
        std::cout << (is_bf16 ? "Reading bf16 CSV file completed!!!"
                              : "Reading fp16 CSV file completed!!!")
                  << std::endl;
    });

    // Report the column on every call, not only on the call that performed the load.
    // Otherwise a caller passing a fresh out-parameter later never receives it.
    if (solidx != nullptr && cache.solidx_col >= 0) {
        *solidx = cache.solidx_col;
    }
    return cache.rows;
}