// profile_grouped_gemm_impl.hpp
#pragma once
#include <iomanip>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {

Jing Zhang's avatar
Jing Zhang committed
19
20
21
22
23
24
25
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGroupedGemmNoOpPtr>&);
Jing Zhang's avatar
Jing Zhang committed
26
27
28
29
30
31
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGroupedGemmNoOpPtr>&);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
    std::vector<DeviceGroupedGemmNoOpPtr>&);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
    std::vector<DeviceGroupedGemmNoOpPtr>&);
Jing Zhang's avatar
Jing Zhang committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

} // namespace device_grouped_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
void profile_grouped_gemm_impl(int do_verification,
Jing Zhang's avatar
Jing Zhang committed
48
49
50
51
52
53
54
55
56
                               int init_method,
                               bool do_log,
                               int nrepeat,
                               std::vector<int> Ms,
                               std::vector<int> Ns,
                               std::vector<int> Ks,
                               std::vector<int> StrideAs,
                               std::vector<int> StrideBs,
                               std::vector<int> StrideCs)
Jing Zhang's avatar
Jing Zhang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
{
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

Jing Zhang's avatar
Jing Zhang committed
72
73
74
75
    int group_count = Ms.size();

    if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
         group_count == StrideBs.size() && group_count == StrideCs.size()))
Jing Zhang's avatar
clean  
Jing Zhang committed
76
    {
Jing Zhang's avatar
Jing Zhang committed
77
        throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
Jing Zhang's avatar
clean  
Jing Zhang committed
78
79
    }

Jing Zhang's avatar
Jing Zhang committed
80
81
    std::vector<Tensor<ADataType>> a_m_k;
    std::vector<Tensor<BDataType>> b_k_n;
Jing Zhang's avatar
Jing Zhang committed
82
83
    std::vector<Tensor<CDataType>> c_m_n_device_results;

Jing Zhang's avatar
Jing Zhang committed
84
85
    for(int i = 0; i < Ms.size(); i++)
    {
Jing Zhang's avatar
Jing Zhang committed
86
87
88
89
90
91
92
        a_m_k.push_back(
            Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
        b_k_n.push_back(
            Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));

        c_m_n_device_results.push_back(
            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
Jing Zhang's avatar
Jing Zhang committed
93

Jing Zhang's avatar
clean  
Jing Zhang committed
94
95
96
        std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
                  << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
                  << "]:" << c_m_n_device_results[i].mDesc << std::endl;
Jing Zhang's avatar
Jing Zhang committed
97
98
99
100

        std::size_t num_thread = std::thread::hardware_concurrency();
        switch(init_method)
        {
Jing Zhang's avatar
Jing Zhang committed
101
102
103
104
105
106
107
108
        case 0: break;
        case 1:
            a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
            b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
            break;
        default:
            a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
            b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
Jing Zhang's avatar
Jing Zhang committed
109
110
        }

Jing Zhang's avatar
Jing Zhang committed
111
112
        c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
    }
Jing Zhang's avatar
Jing Zhang committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    // if(do_verification)
    // {

    // }

Jing Zhang's avatar
clean  
Jing Zhang committed
127
128
    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
    std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
Jing Zhang's avatar
Jing Zhang committed
129

Jing Zhang's avatar
Jing Zhang committed
130
131
132
    a_device_buf.reserve(group_count);
    b_device_buf.reserve(group_count);
    c_device_buf.reserve(group_count);
Jing Zhang's avatar
Jing Zhang committed
133

Jing Zhang's avatar
Jing Zhang committed
134
135
136
137
138
139
140
141
142
143
144
145
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

    p_a.reserve(group_count);
    p_b.reserve(group_count);
    p_c.reserve(group_count);

    std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;

    gemm_shapes.reserve(group_count);

    for(int i = 0; i < group_count; i++)
Jing Zhang's avatar
Jing Zhang committed
146
    {
147
        a_device_buf.emplace_back(
Jing Zhang's avatar
clean  
Jing Zhang committed
148
            std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize()));
149
        b_device_buf.emplace_back(
Jing Zhang's avatar
clean  
Jing Zhang committed
150
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize()));
Jing Zhang's avatar
clean  
Jing Zhang committed
151

152
        c_device_buf.emplace_back(std::make_unique<DeviceMem>(
Jing Zhang's avatar
clean  
Jing Zhang committed
153
154
155
156
157
            sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize()));

        a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
        b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
        c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
Jing Zhang's avatar
Jing Zhang committed
158

Jing Zhang's avatar
Jing Zhang committed
159
160
161
162
163
        gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]});

        p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
        p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
        p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
Jing Zhang's avatar
Jing Zhang committed
164
165
166
    }

    // add device GEMM instances
Jing Zhang's avatar
Jing Zhang committed
167
168
169
    std::vector<
        ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
        gemm_ptrs;
Jing Zhang's avatar
Jing Zhang committed
170
171

    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
Jing Zhang's avatar
Jing Zhang committed
172
                 is_same<CDataType, half_t>::value)
Jing Zhang's avatar
Jing Zhang committed
173
174
175
176
177
178
179
180
181
182
183
184
    {
        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
            ck::tensor_operation::device::device_grouped_gemm_instance::
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
Jing Zhang's avatar
Jing Zhang committed
185
186
            ck::tensor_operation::device::device_grouped_gemm_instance::
                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
Jing Zhang's avatar
Jing Zhang committed
187
188
189
190
191
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
Jing Zhang's avatar
Jing Zhang committed
192
193
            ck::tensor_operation::device::device_grouped_gemm_instance::
                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
Jing Zhang's avatar
Jing Zhang committed
194
195
196
197
198
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
Jing Zhang's avatar
Jing Zhang committed
199
200
            ck::tensor_operation::device::device_grouped_gemm_instance::
                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
Jing Zhang's avatar
Jing Zhang committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
        }
    }

    if(gemm_ptrs.size() <= 0)
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

    std::string best_gemm_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
    for(auto& gemm_ptr : gemm_ptrs)
    {
        auto argument_ptr =
Jing Zhang's avatar
Jing Zhang committed
218
219
220
221
            gemm_ptr->MakeArgumentPointer(p_a,
                                          p_b,
                                          p_c,
                                          gemm_shapes,
Jing Zhang's avatar
Jing Zhang committed
222
223
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
Jing Zhang's avatar
Jing Zhang committed
224
                                          ck::tensor_operation::element_wise::PassThrough{});
Jing Zhang's avatar
Jing Zhang committed
225
226
227
228
229
230
231
232
233

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::string gemm_name = gemm_ptr->GetTypeString();

            float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);

Jing Zhang's avatar
clean  
Jing Zhang committed
234
235
236
237
238
            std::size_t flop = 0, num_btype = 0;
            for(int i = 0; i < gemm_shapes.size(); i++)
            {
                flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];

Jing Zhang's avatar
clean  
Jing Zhang committed
239
                num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
Jing Zhang's avatar
clean  
Jing Zhang committed
240
241
                             sizeof(CDataType) * Ms[i] * Ns[i];
            }
Jing Zhang's avatar
Jing Zhang committed
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;
            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << gemm_name << std::endl;

            if(tflops > best_tflops)
            {
                best_gemm_name  = gemm_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
Jing Zhang's avatar
Jing Zhang committed
259
260
                for(int i = 0; i < gemm_shapes.size(); i++)
                {
Jing Zhang's avatar
Jing Zhang committed
261

Jing Zhang's avatar
clean  
Jing Zhang committed
262
                    c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
Jing Zhang's avatar
Jing Zhang committed
263
264

                    Tensor<CDataType> c_m_n_host_result(
Jing Zhang's avatar
Jing Zhang committed
265
                        f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
Jing Zhang's avatar
Jing Zhang committed
266
267
268
269
270
271
272
273
274
275
276
277

                    using ReferenceGemmInstance =
                        ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                  BDataType,
                                                                  CDataType,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  CElementOp>;

                    auto ref_gemm    = ReferenceGemmInstance{};
                    auto ref_invoker = ref_gemm.MakeInvoker();

Jing Zhang's avatar
Jing Zhang committed
278
279
280
281
282
283
                    auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
                                                              b_k_n[i],
                                                              c_m_n_host_result,
                                                              a_element_op,
                                                              b_element_op,
                                                              c_element_op);
Jing Zhang's avatar
Jing Zhang committed
284
285

                    ref_invoker.Run(ref_argument);
Jing Zhang's avatar
Jing Zhang committed
286
                    check_error(c_m_n_host_result, c_m_n_device_results[i]);
Jing Zhang's avatar
Jing Zhang committed
287
288
289

                    if(do_log)
                    {
Jing Zhang's avatar
clean  
Jing Zhang committed
290
291
292
                        LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
                            << std::endl;
                        LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
Jing Zhang's avatar
Jing Zhang committed
293
                        LogRangeAsType<float>(
Jing Zhang's avatar
Jing Zhang committed
294
                            std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
Jing Zhang's avatar
Jing Zhang committed
295
                            << std::endl;
Jing Zhang's avatar
clean  
Jing Zhang committed
296
297
298
                        LogRangeAsType<float>(
                            std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                            << std::endl;
Jing Zhang's avatar
Jing Zhang committed
299
300
301
302
303
304
305
306
307
308
309
310
                    }
                }
            }
        }
        else
        {
            std::cout << "does not support this GEMM problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
Jing Zhang's avatar
Jing Zhang committed
311
} // namespace profiler
Jing Zhang's avatar
Jing Zhang committed
312
313
314

} // namespace profiler
} // namespace ck