// profile_gemm_impl.hpp
#pragma once

#include <cstddef>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

ltqin's avatar
ltqin committed
21
22
23
24
25
26
27
28
29
30
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

31
32
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
33
34
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
35
36
37
38
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
39

Chao Liu's avatar
Chao Liu committed
40
41
42
43
44
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

45
46
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
47
48
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
49
50
51
52
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);
53

Chao Liu's avatar
Chao Liu committed
54
55
56
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>&);

ltqin's avatar
ltqin committed
57
58
59
60
61
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

62
63
64
65
66
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

ltqin's avatar
ltqin committed
67
68
69
70
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
71

zjing14's avatar
zjing14 committed
72
73
74
75
76
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

77
78
79
80
81
82
83
84
85
86
87
88
89
90
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
Chao Liu's avatar
Chao Liu committed
91
92
93
void profile_gemm_impl(int do_verification,
                       int init_method,
                       bool do_log,
JD's avatar
JD committed
94
                       bool time_kernel,
Chao Liu's avatar
Chao Liu committed
95
96
97
98
99
                       int M,
                       int N,
                       int K,
                       int StrideA,
                       int StrideB,
ltqin's avatar
ltqin committed
100
                       int StrideC,
zjing14's avatar
zjing14 committed
101
                       int KBatch)
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
{
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
123
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
124

125
    std::size_t num_thread = 1;
126
127
128
129
    switch(init_method)
    {
    case 0: break;
    case 1:
ltqin's avatar
ltqin committed
130
131
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
132
133
        break;
    default:
ltqin's avatar
ltqin committed
134
135
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
136
    }
Chao Liu's avatar
Chao Liu committed
137

ltqin's avatar
ltqin committed
138
139
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
140

Chao Liu's avatar
Chao Liu committed
141
142
143
144
145
146
147
148
    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

149
150
151
152
153
154
155
156
157
    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // add device GEMM instances
Chao Liu's avatar
Chao Liu committed
158
    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
159

ltqin's avatar
ltqin committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
                 is_same<CDataType, float>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
176
177
178

                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
194
195
196

                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
212
213
214

                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
            }
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
230
231
232

                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
ltqin's avatar
ltqin committed
233
234
235
236
237
238
239
240
241
242
            }
        }
    }
    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                      is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
243
244
245
246
247
248
249
250
251
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
252

zjing14's avatar
zjing14 committed
253
254
255
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
256
257
258
259
260
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
261
262
263
264
265
266
267
268
269
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
270

zjing14's avatar
zjing14 committed
271
272
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
273

zjing14's avatar
zjing14 committed
274
275
276
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
277
278
279
280
281
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
282
283
284
285
286
287
288
289
290
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
291

zjing14's avatar
zjing14 committed
292
293
294
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
295
296
297
298
299
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
zjing14's avatar
zjing14 committed
300
301
302
303
304
305
306
307
308
            if(KBatch > 1)
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
            }
            else
            {
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
Chao Liu's avatar
Chao Liu committed
309

zjing14's avatar
zjing14 committed
310
311
312
                ck::tensor_operation::device::device_gemm_instance::
                    add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
            }
ltqin's avatar
ltqin committed
313
314
        }
    }
315
316
317
318
319
    else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                      is_same<BDataType, ck::bhalf_t>::value &&
                      is_same<CDataType, ck::bhalf_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
320
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
321
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
322
323
324
325
326
327
328
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
329
330
331
332
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs);
        }
333
334
335
336
337
338
339
340
341
342
343
344
345
346
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
        }
347
348
349
350
351
    }
    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
                      is_same<CDataType, int8_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
352
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
353
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
354
355
356
357
358
359
360
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
361
362
363
364
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs);
        }
365
366
367
368
369
370
371
372
373
374
375
376
377
378
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_gemm_instance::
                add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs);
        }
379
    }
380
381
382
383
384
385

    if(gemm_ptrs.size() <= 0)
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

Chao Liu's avatar
Chao Liu committed
386
    std::string best_gemm_name;
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& gemm_ptr : gemm_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
Chao Liu's avatar
Chao Liu committed
403
404
405
                                          StrideC,
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
ltqin's avatar
ltqin committed
406
407
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          KBatch);
408
409
410
411
412

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
413
414
415
416
            // re-init C to zero before profiling next kernel
            c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
            c_device_buf.ToDevice(c_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
417
418
            std::string gemm_name = gemm_ptr->GetTypeString();

JD's avatar
JD committed
419
420
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
421
422

            std::size_t flop = std::size_t(2) * M * N * K;
Chao Liu's avatar
Chao Liu committed
423

424
            std::size_t num_btype =
Chao Liu's avatar
Chao Liu committed
425
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
426
427
428
429
430

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

Chao Liu's avatar
Chao Liu committed
431
432
            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << gemm_name << std::endl;
433
434
435

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
436
                best_gemm_name  = gemm_name;
437
438
439
440
441
442
443
444
445
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_m_n_device_result.mData.data());

446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
                if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
                             is_same<BDataType, ck::bhalf_t>::value &&
                             is_same<CDataType, ck::bhalf_t>::value)
                {
                    Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
                    Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
                    Tensor<float> c_m_n_host_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
                    Tensor<float> c_m_n_device_f32_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

                    bf16_to_f32_(a_m_k, a_f32_m_k);
                    bf16_to_f32_(b_k_n, b_f32_k_n);
                    bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);

                    using ReferenceGemmInstance = ck::tensor_operation::host::
                        ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;

                    auto ref_gemm    = ReferenceGemmInstance{};
                    auto ref_invoker = ref_gemm.MakeInvoker();

                    auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
                                                              b_f32_k_n,
                                                              c_m_n_host_result,
                                                              a_element_op,
                                                              b_element_op,
                                                              c_element_op);

                    ref_invoker.Run(ref_argument);

476
                    ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504

                    if(do_log)
                    {
                        LogRangeAsType<float>(
                            std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                    }
                }
                else
                {
                    Tensor<CDataType> c_m_n_host_result(
                        f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

                    using ReferenceGemmInstance =
                        ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                  BDataType,
                                                                  CDataType,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  CElementOp>;

                    auto ref_gemm    = ReferenceGemmInstance{};
                    auto ref_invoker = ref_gemm.MakeInvoker();

                    auto ref_argument = ref_gemm.MakeArgument(
                        a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

                    ref_invoker.Run(ref_argument);
505
                    ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
506
507
508
509
510
511
512
513

                    if(do_log)
                    {
                        LogRangeAsType<float>(
                            std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                            << std::endl;
                    }
                }
514
515
516
517
518
519
520
521
522
523
524
525

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
526
            std::cout << "does not support this GEMM problem" << std::endl;
527
528
529
530
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
531
              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
532
533
534
535
}

} // namespace profiler
} // namespace ck