"vscode:/vscode.git/clone" did not exist on "2e48584b62b3b5823ef234eeeadab12b5ee6098b"
profile_batched_gemm_reduce_impl.hpp 15.5 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

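// Aliases for the fused GEMM + reduction instances: each kernel writes
// C = A * B and two row-wise reductions of C per batch, a plain sum
// (PassThrough in, Add) and a sum of squares (UnarySquare in, Add), both
// stored as F32.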
using F32                 = float;
using F16                 = ck::half_t;
using ReducePtrsGlobal    = ck::Tuple<F32*, F32*>;
using Identity            = ck::tensor_operation::element_wise::PassThrough;
using Square              = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps  = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;

using DeviceGemmReduceNoOpPtr =
    ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;

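// Instance factories, one per A/B layout combination; C is row-major (gmn)
// in all of them. Definitions live in the device instance library.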
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmReduceNoOpPtr>&);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ReduceDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
bool profile_batched_gemm_reduce_impl(int do_verification,
                                      int init_method,
                                      bool do_log,
                                      bool time_kernel,
                                      int M,
                                      int N,
                                      int K,
                                      int StrideA,
                                      int StrideB,
                                      int StrideC,
                                      int BatchCount)
{
    bool pass = true;

    using namespace ck::literals;

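    // Host tensors are laid out as (batch, row, col) with packed batches:
    // the batch stride is row * stride for row-major layouts and
    // col * stride for column-major ones.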
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       auto layout) {
        if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
        {
            return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride});
        }
    };

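    // A/B inputs plus separate host- and device-side copies of the outputs
    // for verification; d0/d1 hold one reduced value per (batch, m) row.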
    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));

    Tensor<CDataType> c_g_m_n_host_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    Tensor<ReduceDataType> d0_g_m_host_result({BatchCount, M});
    Tensor<ReduceDataType> d1_g_m_host_result({BatchCount, M});

    Tensor<CDataType> c_g_m_n_device_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    Tensor<ReduceDataType> d0_g_m_device_result({BatchCount, M});
    Tensor<ReduceDataType> d1_g_m_device_result({BatchCount, M});

    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.GetDesc() << std::endl;
    std::cout << "d0_g_m: " << d0_g_m_host_result.GetDesc() << std::endl;
    std::cout << "d1_g_m: " << d1_g_m_host_result.GetDesc() << std::endl;

    std::size_t num_thread = std::thread::hardware_concurrency();
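    // init_method 0 leaves A and B uninitialized, 1 fills them with small
    // integer values (products stay exactly representable, which keeps
    // verification robust), anything else uses uniform floats.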
    switch(init_method)
    {
    case 0: break;
    case 1:
        std::srand(0);
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
        break;
    default:
        std::srand(0);
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
    }

    using AElementOp            = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp            = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp            = ck::tensor_operation::element_wise::PassThrough;
    using ReduceOp0             = ck::reduce::Add;
    using ReduceOp1             = ck::reduce::Add;
    using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
    using UnarySquareElementOp  = ck::tensor_operation::element_wise::UnarySquare;

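    // The type-erased device interface takes the element-wise operators as
    // arrays of void pointers.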
    auto a_element_op                     = AElementOp{};
    auto b_element_op                     = BElementOp{};
    auto c_element_op                     = CElementOp{};
    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};

    const auto reduce0_op = ReduceOp0{};
    const auto reduce1_op = ReduceOp1{};

    auto passthrough                            = UnaryIdenticElementOp{};
    auto square                                 = UnarySquareElementOp{};
    std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square};
    std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};

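    // Host reference: run a reference batched GEMM, then reproduce the fused
    // reductions (sum and sum of squares over n for every (batch, m) row).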
    if(do_verification)
    {
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             float,
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;

        using ReduceAccDataType = ReduceDataType;

        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();

        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);

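        // Mirror the device-side reduction: start from the reduction op's
        // identity value, apply the per-element input op, then accumulate.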
        for(int batch = 0; batch < BatchCount; ++batch)
        {
            for(int m = 0; m < M; ++m)
            {
                auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
                auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();

                for(int n = 0; n < N; ++n)
                {
                    ReduceAccDataType d0_val =
                        ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
                    ReduceAccDataType d1_val;

                    square(d1_val, d0_val);
                    reduce0_op(reduce0_acc, d0_val);
                    reduce1_op(reduce1_acc, d1_val);
                }

                d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
                d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
            }
        }
    }

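    // Allocate device buffers sized from the host tensors and upload A and B.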
    DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
    DeviceMem c_device_buf(c_g_m_n_device_result.GetMemorySize());
    DeviceMem reduce0_device_buf(d0_g_m_device_result.GetMemorySize());
    DeviceMem reduce1_device_buf(d1_g_m_device_result.GetMemorySize());

    std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
                                      reduce1_device_buf.GetDeviceBuffer()};

    a_device_buf.ToDevice(a_g_m_k.data());
    b_device_buf.ToDevice(b_g_k_n.data());

    // add device GEMM instances
    std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;

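    // Prebuilt instances only exist for fp16 A/B/C with f32 reductions and
    // row-major C; any other combination leaves gemm_ptrs empty.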
    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                 is_same<CDataType, half_t>::value)
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::instance::
                add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
                    gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::instance::
                add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
                    gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::instance::
                add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
                    gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::instance::
                add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
                    gemm_ptrs);
        }
    }

    if(gemm_ptrs.empty())
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

    std::string best_gemm_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& gemm_ptr : gemm_ptrs)
    {
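        // The nullptr and empty-brace arguments are the optional bias/D
        // inputs and their element-wise ops, unused by these instances.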
        auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
                                                          b_device_buf.GetDeviceBuffer(),
                                                          nullptr,
                                                          {},
                                                          c_device_buf.GetDeviceBuffer(),
                                                          p_reduces,
                                                          M,
                                                          N,
                                                          K,
                                                          StrideA,
                                                          StrideB,
                                                          StrideC,
                                                          {},
                                                          gemm_element_ops,
                                                          {},
                                                          reduce_in_element_ops,
                                                          reduce_out_element_ops,
                                                          BatchCount);

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            // init D0, D1 to 0
            reduce0_device_buf.SetZero();
            reduce1_device_buf.SetZero();

            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

            std::string gemm_name = gemm_ptr->GetTypeString();

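            // 2 * M * N * K ops per batch (one multiply plus one add per MAC);
            // the byte count covers the A/B reads and the C write and ignores
            // the small d0/d1 outputs. ave_time is in ms, so flop / 1e9 / ms
            // yields TFLOPS and bytes / 1e6 / ms yields GB/s.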
            std::size_t flop      = 2_uz * BatchCount * M * N * K;
            std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
                                    sizeof(BDataType) * BatchCount * K * N +
                                    sizeof(CDataType) * BatchCount * M * N;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << gemm_name << std::endl;

            if(tflops > best_tflops)
            {
                best_gemm_name  = gemm_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.data());
                reduce0_device_buf.FromDevice(d0_g_m_device_result.data());
                reduce1_device_buf.FromDevice(d1_g_m_device_result.data());

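                // check_err returns true when device and host results agree
                // within tolerance.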
                bool c_ok  = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
                bool d0_ok = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
                bool d1_ok = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);

                pass = pass && c_ok && d0_ok && d1_ok;

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a: ", a_g_m_k, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "c_device: ", c_g_m_n_device_result, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "d0_host: ", d0_g_m_host_result, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "d0_device: ", d0_g_m_device_result, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "d1_host: ", d1_g_m_host_result, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "d1_device: ", d1_g_m_device_result, ",")
                        << std::endl;
                }
            }
        }
        else
        {
            std::cout << "does not support this GEMM problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;

    return pass;
}
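
// A minimal invocation sketch (hypothetical sizes and layouts; the actual
// profiler driver derives these from command-line arguments):
//
//   using Row = ck::tensor_layout::gemm::RowMajor;
//   using Col = ck::tensor_layout::gemm::ColumnMajor;
//
//   bool ok = ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
//                                                            ck::half_t,
//                                                            ck::half_t,
//                                                            float,
//                                                            Row,
//                                                            Col,
//                                                            Row>(
//       /*do_verification=*/1,
//       /*init_method=*/1,
//       /*do_log=*/false,
//       /*time_kernel=*/true,
//       /*M=*/256, /*N=*/256, /*K=*/128,
//       /*StrideA=*/128, /*StrideB=*/128, /*StrideC=*/256,
//       /*BatchCount=*/4);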

} // namespace profiler
} // namespace ck