// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"

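// Relative tolerance used when comparing GEMM results against a reference.
// Thresholds are chosen per data type; low-precision types (bf16, int8, f8, bf8)
// are allowed a noticeably larger relative error than float/double.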
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, float>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else if constexpr(std::is_same_v<DataType, ck::half_t>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
    {
        return 5e-2;
    }
    else if constexpr(std::is_same_v<DataType, int32_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, int8_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 2e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 2e-1;
    }
    else
    {
        return 1e-3;
    }
}

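// Absolute tolerance counterpart of get_rtol(), with the same per-type thresholds.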
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, float>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else if constexpr(std::is_same_v<DataType, ck::half_t>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
    {
        return 5e-2;
    }
    else if constexpr(std::is_same_v<DataType, int32_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, int8_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 2e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 2e-1;
    }
    else
    {
        return 1e-3;
    }
}

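// Runs a single GEMM example: initializes host tensors A and B, copies them to the
// device, launches the selected DeviceGemmInstance (regular or Stream-K, depending on
// ProblemType), and optionally verifies the result on the CPU and/or the GPU.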
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif

    using namespace ck::literals;

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;

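    // Build a 2D host tensor descriptor: RowMajor uses strides {stride, 1},
    // ColumnMajor uses strides {1, stride}.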
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // a stride of -1 means "not specified": fall back to the packed (dense) stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<std::size_t>(col);
                }
                else
                {
                    return static_cast<std::size_t>(row);
                }
            }
            else
                return static_cast<std::size_t>(stride);
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

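    // Initialize A and B according to config.init_method; methods 6 and 7 use the PI
    // generators whose result is inspected explicitly after CPU verification below.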
    switch(config.init_method)
    {
    case 0:
        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
        break;
    case 1:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 2:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
        break;
    case 3:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 4:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
        break;
    case 5:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
        break;
    case 6:
        a_m_k.GenerateTensorValue(GeneratorTensor_PI<ADataType>{});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    case 7:
        a_m_k.GenerateTensorValue(GeneratorTensor_PI_A<ADataType>{});
        b_k_n.GenerateTensorValue(GeneratorTensor_PI_B<BDataType>{});
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    const Tensor<KernelADataType> a_m_k_converted(a_m_k);
    const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) *
                                   c_m_n_device_ref_result.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
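    // Device workspace for Stream-K instances that report a non-zero GetWorkSpaceSize();
    // allocated lazily below, right before the kernel launch.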
    DeviceMem workspace;

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

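    // Base class of the Stream-K device GEMM interface; used below to decide whether the
    // instance takes the regular argument list or the Stream-K one (extra NumSKBlocks).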
    using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK<ALayout,
                                                                        BLayout,
                                                                        CLayout,
                                                                        ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        CElementOp>;

    // do GEMM
    auto gemm      = DeviceGemmInstance{};
    auto invoker   = gemm.MakeInvoker();
    float ave_time = 0;

    if constexpr(std::is_same<ProblemType, ProblemSize>::value &&
                 !std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
    {
        auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
            static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
            M,
            N,
            K,
            StrideA,
            StrideB,
            StrideC,
            a_element_op,
            b_element_op,
            c_element_op);

        if(!gemm.IsSupportedArgument(argument))
        {
            std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

            return false;
        }

        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
    }
    else if constexpr(std::is_same<ProblemType, ProblemSizeStreamK>::value &&
                      std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
    {
        auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
            static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
            M,
            N,
            K,
            StrideA,
            StrideB,
            StrideC,
            a_element_op,
            b_element_op,
            c_element_op,
            problem_size.NumSKBlocks);

        if(!gemm.IsSupportedArgument(argument))
        {
            std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

            return false;
        }

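        // Stream-K instances may need a device workspace (e.g. for partial results that
        // are reduced across workgroups); query the size and attach a buffer if required.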
        std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
        if(workspace_size != 0)
        {
            workspace.Realloc(workspace_size);
            gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
        }

        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});

#if 0
        // TODO!!!!!
        if(workspace_size != 0){
            float * ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
            size_t ws_dwords = workspace_size / sizeof(float);
            workspace.FromDevice(ws_ptr);

            for(size_t i = 0; i < ws_dwords; i++) {
                uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
                printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
            }
            free(ws_ptr);
        }
#endif
    }
    else
    {
        // The ProblemType does not match the kind of DeviceGemmInstance (regular vs. Stream-K).

        std::cerr << gemm.GetTypeString() << ": the instance does not support the problem config."
                  << std::endl;
        return false;
    }

    if(config.time_kernel)
    {
        std::size_t flop = 2_uz * M * N * K;
        std::size_t num_btype =
            sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

        // ave_time is reported in milliseconds, so flop / 1e9 / ms gives TFLOP/s
        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

        // bytes / 1e6 / ms gives GB/s
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }
    else
    {
        std::cout << "FINISHED: " << gemm.GetTypeString() << std::endl;
    }

    bool pass = true;

    // do_verification: 1 = verify against the CPU reference, 2 = verify against the GPU
    // reference kernel, 3 = both
    if((config.do_verification == 1) || (config.do_verification == 3))
    {
        // CPU verification
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        std::cout << "Running verification on CPU." << std::endl;
        ref_invoker.Run(ref_argument);

#ifdef BUILD_INT4_EXAMPLE
        Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);

        c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());

        c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();

        return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass = ck::utils::check_err(c_m_n_device_result,
                                    c_m_n_host_result,
                                    "Error: Incorrect results!",
                                    get_rtol<CDataType>(),
                                    get_atol<CDataType>());
#endif
        if(pass)
            std::cout << "Verification on CPU: PASS" << std::endl;

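        // For the PI generators (init_method 6/7) the sampled output element is expected
        // to be close to M_PI; print device/host samples for manual inspection.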
        if(config.init_method == 6 || config.init_method == 7)
        {
            std::cout << std::fixed << std::setprecision(16);

            AccDataType d = ck::type_convert<AccDataType>(c_m_n_device_result(0, 10));
            AccDataType h = ck::type_convert<AccDataType>(c_m_n_host_result(10, 0));
            std::cout << "device result: " << d << std::endl;
            std::cout << "host result: " << h << std::endl;
            std::cout << "expected result: " << M_PI << std::endl;
            std::cout << "device - host: " << std::abs(d - h) << std::endl;
            std::cout << "device - expected: " << std::abs(d - M_PI) << std::endl;
            std::cout << "atol: " << get_atol<CDataType>() << std::endl;
            std::cout << std::endl << std::endl;
        }
    }

    if((config.do_verification == 2) || (config.do_verification == 3))
    {
        // GPU verification
        auto ref_gemm_gpu    = ReferenceGemmInstanceGPU{};
        auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();

        auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
            M,
            N,
            K,
            a_element_op,
            b_element_op,
            c_element_op);

        std::cout << "Running verification on GPU." << std::endl;
        ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{});

        c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass = ck::utils::check_err(c_m_n_device_result,
                                    c_m_n_device_ref_result,
                                    "Error: Incorrect results!",
                                    get_rtol<CDataType>(),
                                    get_atol<CDataType>());
        if(pass)
            std::cout << "Verification on GPU: PASS" << std::endl;
    }

    return pass;
}

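// Entry point for the regular (data-parallel) GEMM examples: parse the command line into
// ProblemSize/ExecutionConfig and run the selected device instance.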
bool run_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}

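// Entry point for the Stream-K GEMM examples; ProblemSizeStreamK additionally carries
// NumSKBlocks, which is forwarded to the Stream-K kernel argument.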
bool run_gemm_streamk_example(int argc, char* argv[])
{
    ProblemSizeStreamK problem_size;
    ExecutionConfig config;

    return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}