// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>

#include "profiler/include/profile_gemm_impl.hpp"
Chao Liu's avatar
Chao Liu committed
10

Chao Liu's avatar
Chao Liu committed
11
// Memory layout of the A, B and C matrices of a GEMM, named <A>_<B>_<C>.
//
// The numeric values are part of the command-line contract: profile_gemm
// casts its "matrix layout" argument (arg3) directly to this enum, so the
// values are pinned explicitly and must not be renumbered.
//
// Values 0-3 match the layouts listed in profile_gemm's usage text. Values
// 4-7 have no dispatch branch in profile_gemm and are rejected there as
// "not implemented".
// NOTE(review): by the naming pattern, the *_NM entries presumably describe
// C[n, m]; confirm against their users.
enum struct GemmMatrixLayout
{
    MK_KN_MN = 0, // A[m, k] * B[k, n] = C[m, n]
    MK_NK_MN = 1, // A[m, k] * B[n, k] = C[m, n]
    KM_KN_MN = 2, // A[k, m] * B[k, n] = C[m, n]
    KM_NK_MN = 3, // A[k, m] * B[n, k] = C[m, n]
    MK_KN_NM = 4,
    MK_NK_NM = 5,
    KM_KN_NM = 6,
    KM_NK_NM = 7,
};

Chao Liu's avatar
Chao Liu committed
23
// Element-type combination selected by profile_gemm's "data type" argument.
//
// The numeric values are part of the command-line contract: profile_gemm
// casts arg2 directly to this enum, so the values are pinned explicitly and
// must not be renumbered.
enum struct GemmDataType
{
    F32_F32_F32    = 0, // fp32
    F16_F16_F16    = 1, // fp16
    BF16_BF16_BF16 = 2, // bf16
    INT8_INT8_INT8 = 3, // int8
};

int profile_gemm(int argc, char* argv[])
{
ltqin's avatar
ltqin committed
33
    if(!(argc == 14 || argc == 15))
Chao Liu's avatar
Chao Liu committed
34
35
    {
        printf("arg1: tensor operation (gemm: GEMM)\n");
36
        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
Chao Liu's avatar
Chao Liu committed
37
38
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
ltqin's avatar
ltqin committed
39
40
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
Chao Liu's avatar
Chao Liu committed
41
42
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
JD's avatar
JD committed
43
44
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: time kernel (0=n0, 1=yes)\n");
Chao Liu's avatar
Chao Liu committed
45
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
ltqin's avatar
ltqin committed
46
        printf("arg14: split k into  mulitiple batch\n");
Chao Liu's avatar
Chao Liu committed
47
48
49
        exit(1);
    }

Chao Liu's avatar
Chao Liu committed
50
51
    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
Chao Liu's avatar
Chao Liu committed
52
53
54
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
JD's avatar
JD committed
55
    const bool time_kernel     = std::stoi(argv[7]);
Chao Liu's avatar
Chao Liu committed
56
57
58
59
60
61
62
63

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
ltqin's avatar
ltqin committed
64
65
66
    int KBatch        = 1;
    if(argc == 15)
        KBatch = std::stoi(argv[14]);
Chao Liu's avatar
Chao Liu committed
67
68
69
70
71
72

    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t,
                                        ck::half_t,
                                        ck::half_t,
73
                                        float,
Chao Liu's avatar
Chao Liu committed
74
75
76
77
78
79
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
80
            time_kernel,
Chao Liu's avatar
Chao Liu committed
81
82
83
84
85
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? N : StrideB,
zjing14's avatar
zjing14 committed
86
87
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
88
89
90
91
92
93
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t,
                                        ck::half_t,
                                        ck::half_t,
94
                                        float,
Chao Liu's avatar
Chao Liu committed
95
96
97
98
99
100
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
101
            time_kernel,
Chao Liu's avatar
Chao Liu committed
102
103
104
105
106
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? K : StrideB,
zjing14's avatar
zjing14 committed
107
108
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
109
110
111
112
113
114
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t,
                                        ck::half_t,
                                        ck::half_t,
115
                                        float,
Chao Liu's avatar
Chao Liu committed
116
117
118
119
120
121
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
122
            time_kernel,
Chao Liu's avatar
Chao Liu committed
123
124
125
126
127
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? N : StrideB,
zjing14's avatar
zjing14 committed
128
129
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
130
131
132
133
134
135
    }
    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::half_t,
                                        ck::half_t,
                                        ck::half_t,
136
                                        float,
Chao Liu's avatar
Chao Liu committed
137
138
139
140
141
142
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
143
            time_kernel,
Chao Liu's avatar
Chao Liu committed
144
145
146
147
148
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
zjing14's avatar
zjing14 committed
149
150
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
151
152
153
154
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<float,
155
                                        float,
Chao Liu's avatar
Chao Liu committed
156
157
158
159
160
161
162
163
                                        float,
                                        float,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
164
            time_kernel,
Chao Liu's avatar
Chao Liu committed
165
166
167
168
169
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? N : StrideB,
ltqin's avatar
ltqin committed
170
171
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
172
173
174
175
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<float,
176
                                        float,
Chao Liu's avatar
Chao Liu committed
177
178
179
180
181
182
183
184
                                        float,
                                        float,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
185
            time_kernel,
Chao Liu's avatar
Chao Liu committed
186
187
188
189
190
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? K : StrideB,
ltqin's avatar
ltqin committed
191
192
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
193
194
195
196
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<float,
197
                                        float,
Chao Liu's avatar
Chao Liu committed
198
199
200
201
202
203
204
205
                                        float,
                                        float,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
206
            time_kernel,
Chao Liu's avatar
Chao Liu committed
207
208
209
210
211
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? N : StrideB,
ltqin's avatar
ltqin committed
212
213
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
214
215
216
217
    }
    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<float,
218
                                        float,
Chao Liu's avatar
Chao Liu committed
219
220
221
222
223
224
225
226
                                        float,
                                        float,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
227
            time_kernel,
Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
ltqin's avatar
ltqin committed
233
234
            (StrideC < 0) ? N : StrideC,
            KBatch);
Chao Liu's avatar
Chao Liu committed
235
    }
236
237
238
239
240
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t,
                                        int8_t,
                                        int8_t,
241
                                        int32_t,
242
243
244
245
246
247
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
248
            time_kernel,
249
250
251
252
253
254
255
256
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
257
258
259
260
261
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t,
                                        int8_t,
                                        int8_t,
262
                                        int32_t,
263
264
265
266
267
268
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
269
            time_kernel,
270
271
272
273
274
275
276
277
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
278
279
280
281
282
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t,
                                        int8_t,
                                        int8_t,
283
                                        int32_t,
284
285
286
287
288
289
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
290
            time_kernel,
291
292
293
294
295
296
297
298
299
300
301
302
303
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<int8_t,
                                        int8_t,
                                        int8_t,
304
                                        int32_t,
305
306
307
308
309
310
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
311
            time_kernel,
312
313
314
315
316
317
318
319
320
321
322
323
324
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                        ck::bhalf_t,
                                        ck::bhalf_t,
325
                                        float,
326
327
328
329
330
331
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
332
            time_kernel,
333
334
335
336
337
338
339
340
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
341
342
343
344
345
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                        ck::bhalf_t,
                                        ck::bhalf_t,
346
                                        float,
347
348
349
350
351
352
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
353
            time_kernel,
354
355
356
357
358
359
360
361
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
362
363
364
365
366
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                        ck::bhalf_t,
                                        ck::bhalf_t,
367
                                        float,
368
369
370
371
372
373
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
374
            time_kernel,
375
376
377
378
379
380
381
382
383
384
385
386
387
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
    {
        ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                        ck::bhalf_t,
                                        ck::bhalf_t,
388
                                        float,
389
390
391
392
393
394
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::ColumnMajor,
                                        ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
JD's avatar
JD committed
395
            time_kernel,
396
397
398
399
400
401
402
403
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC,
            KBatch);
    }
Chao Liu's avatar
Chao Liu committed
404
405
406
407
408
    else
    {
        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
    }

409
    return 0;
Chao Liu's avatar
Chao Liu committed
410
}