// profile_gemm_impl.hpp
#pragma once

#include <cstddef>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

ltqin's avatar
ltqin committed
24
25
26
27
28
29
30
31
32
33
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

34
35
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
36
37
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
38
39
40
41
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
42

Chao Liu's avatar
Chao Liu committed
43
44
45
46
47
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

Jianfeng Yan's avatar
Jianfeng Yan committed
48
49
50
51
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
52

Chao Liu's avatar
Chao Liu committed
53
54
55
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);

ltqin's avatar
ltqin committed
56
57
58
59
60
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

61
62
63
64
65
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

ltqin's avatar
ltqin committed
66
67
68
69
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
70

zjing14's avatar
zjing14 committed
71
72
73
74
75
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

Jianfeng Yan's avatar
Jianfeng Yan committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

91
92
93
94
95
96
97
98
99
100
101
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
102
          typename AccDataType,
103
104
105
          typename ALayout,
          typename BLayout,
          typename CLayout>
Chao Liu's avatar
Chao Liu committed
106
107
108
void profile_gemm_impl(int do_verification,
                       int init_method,
                       bool do_log,
JD's avatar
JD committed
109
                       bool time_kernel,
Chao Liu's avatar
Chao Liu committed
110
111
112
113
114
                       int M,
                       int N,
                       int K,
                       int StrideA,
                       int StrideB,
ltqin's avatar
ltqin committed
115
                       int StrideC,
zjing14's avatar
zjing14 committed
116
                       int KBatch)
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
{
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
138
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
139

140
    std::size_t num_thread = 1;
141
142
    switch(init_method)
    {
Jianfeng Yan's avatar
Jianfeng Yan committed
143
144
145
146
147
    // case 0: break;
    case 0:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{}, num_thread);
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{}, num_thread);
        break;
148
    case 1:
ltqin's avatar
ltqin committed
149
150
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
151
152
        break;
    default:
ltqin's avatar
ltqin committed
153
154
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
155
    }
Chao Liu's avatar
Chao Liu committed
156

ltqin's avatar
ltqin committed
157
158
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
159

Chao Liu's avatar
Chao Liu committed
160
161
162
163
164
165
166
167
    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

168
169
170
171
172
173
174
175
176
    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // add device GEMM instances
Chao Liu's avatar
Chao Liu committed
177
    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
178

ltqin's avatar
ltqin committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
195

Jianfeng Yan's avatar
Jianfeng Yan committed
196
197
198
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);

199
200
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
216

Jianfeng Yan's avatar
Jianfeng Yan committed
217
218
219
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);

220
221
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
237

Jianfeng Yan's avatar
Jianfeng Yan committed
238
239
240
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);

241
242
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
258

Jianfeng Yan's avatar
Jianfeng Yan committed
259
260
261
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);

262
263
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
264
265
266
267
268
269
270
271
272
273
            }
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
274
275
276
277
278
279
280
281
282
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
283

Jianfeng Yan's avatar
Jianfeng Yan committed
284
285
286
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);

zjing14's avatar
zjing14 committed
287
288
289
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
290
291
292
293
294
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
295
296
297
298
299
300
301
302
303
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
304

Jianfeng Yan's avatar
Jianfeng Yan committed
305
306
307
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);

zjing14's avatar
zjing14 committed
308
309
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
310

zjing14's avatar
zjing14 committed
311
312
313
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
314
315
316
317
318
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
319
320
321
322
323
324
325
326
327
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
328

Jianfeng Yan's avatar
Jianfeng Yan committed
329
330
331
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);

zjing14's avatar
zjing14 committed
332
333
334
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
335
336
337
338
339
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
340
341
342
343
344
345
346
347
348
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
349

Jianfeng Yan's avatar
Jianfeng Yan committed
350
351
352
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);

zjing14's avatar
zjing14 committed
353
354
355
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
356
357
        }
    }
358
359
360
361
362
    else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                      is_same<BDataType, ck::bhalf_t>::value &&
                      is_same<CDataType, ck::bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
363
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
364
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
365
366
367
368
369
370
371
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
372
373
374
375
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
        }
376
377
378
379
380
381
382
383
384
385
386
387
388
389
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
        }
390
391
392
393
394
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
395
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
396
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
397
398
        {
            ck::tensor_operation::device::device_gemm_instance::
Jianfeng Yan's avatar
Jianfeng Yan committed
399
400
401
402
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs);
403
404
405
406
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
407
408
        {
            ck::tensor_operation::device::device_gemm_instance::
Jianfeng Yan's avatar
Jianfeng Yan committed
409
410
411
412
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs);
413
        }
414
415
416
417
418
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
Jianfeng Yan's avatar
Jianfeng Yan committed
419
420
421
422
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs);
423
424
425
426
427
428
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
Jianfeng Yan's avatar
Jianfeng Yan committed
429
430
431
432
                add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);

            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs);
433
        }
434
    }
435
436
437
438
439
440

    if(gemm_ptrs.size() <= 0)
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

Chao Liu's avatar
Chao Liu committed
441
    std::string best_gemm_name;
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& gemm_ptr : gemm_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
Chao Liu's avatar
Chao Liu committed
458
459
460
                                          StrideC,
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
ltqin's avatar
ltqin committed
461
462
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          KBatch);
463
464
465
466
467

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
468
469
470
471
            // re-init C to zero before profiling next kernel
            c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
            c_device_buf.ToDevice(c_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
472
473
            std::string gemm_name = gemm_ptr->GetTypeString();

JD's avatar
JD committed
474
475
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
476
477

            std::size_t flop = std::size_t(2) * M * N * K;
Chao Liu's avatar
Chao Liu committed
478

479
            std::size_t num_btype =
Chao Liu's avatar
Chao Liu committed
480
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
481
482
483
484
485

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

Chao Liu's avatar
Chao Liu committed
486
487
            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << gemm_name << std::endl;
488
489
490

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
491
                best_gemm_name  = gemm_name;
492
493
494
495
496
497
498
499
500
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_m_n_device_result.mData.data());

501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
                if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                             is_same<BDataType, ck::bhalf_t>::value &&
                             is_same<CDataType, ck::bhalf_t>::value)
                {
                    Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
                    Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
                    Tensor<float> c_m_n_host_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
                    Tensor<float> c_m_n_device_f32_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

                    bf16_to_f32_(a_m_k, a_f32_m_k);
                    bf16_to_f32_(b_k_n, b_f32_k_n);
                    bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);

516
517
518
519
520
521
522
523
                    using ReferenceGemmInstance =
                        ck::tensor_operation::host::ReferenceGemm<float,
                                                                  float,
                                                                  float,
                                                                  float,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  CElementOp>;
524
525
526
527
528
529
530
531
532
533
534
535
536

                    auto ref_gemm    = ReferenceGemmInstance{};
                    auto ref_invoker = ref_gemm.MakeInvoker();

                    auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
                                                              b_f32_k_n,
                                                              c_m_n_host_result,
                                                              a_element_op,
                                                              b_element_op,
                                                              c_element_op);

                    ref_invoker.Run(ref_argument);

537
                    ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554

                    if(do_log)
                    {
                        LogRangeAsType<float>(
                            std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                    }
                }
                else
                {
                    Tensor<CDataType> c_m_n_host_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

                    using ReferenceGemmInstance =
                        ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                  BDataType,
                                                                  CDataType,
555
                                                                  AccDataType,
556
557
558
559
560
561
562
563
564
565
566
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  CElementOp>;

                    auto ref_gemm    = ReferenceGemmInstance{};
                    auto ref_invoker = ref_gemm.MakeInvoker();

                    auto ref_argument = ref_gemm.MakeArgument(
                        a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

                    ref_invoker.Run(ref_argument);
567
                    ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
568
569
570
571
572
573
574
575

                    if(do_log)
                    {
                        LogRangeAsType<float>(
                            std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                    }
                }
576
577
578
579
580
581
582
583
584
585
586
587

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Jianfeng Yan's avatar
Jianfeng Yan committed
588
589
            std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem"
                      << std::endl;
590
591
592
        }
    }

593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
    if constexpr(is_same<CDataType, float>::value)
    {
        std::cout << "Best Perf for datatype = f32";
    }
    else if constexpr(is_same<CDataType, half_t>::value)
    {
        std::cout << "Best Perf for datatype = f16";
    }
    else if constexpr(is_same<CDataType, bhalf_t>::value)
    {
        std::cout << "Best Perf for datatype = bf16";
    }
    else if constexpr(is_same<CDataType, int8_t>::value)
    {
        std::cout << "Best Perf for datatype = int8";
    }

    if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " ALayout =  RowMajor";
    }
    else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " ALayout =  ColumnMajor";
    }

    if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " BLayout =  RowMajor";
    }
    else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " BLayout =  ColumnMajor";
    }

    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
              << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
              << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
              << best_gemm_name << std::endl;
632
633
634
635
}

} // namespace profiler
} // namespace ck