profile_gemm_impl.hpp 9.76 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

Chao Liu's avatar
Chao Liu committed
6
#include <iomanip>
7
8
#include <iostream>
#include <typeinfo>
9

Chao Liu's avatar
Chao Liu committed
10
11
12
13
14
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

15
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
Chao Liu's avatar
Chao Liu committed
16

Chao Liu's avatar
Chao Liu committed
17
18
19
20
21
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
22
23
24
25
26
27

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
28
          typename AccDataType,
Chao Liu's avatar
Chao Liu committed
29
          typename CDataType,
30
31
32
          typename ALayout,
          typename BLayout,
          typename CLayout>
Chao Liu's avatar
Chao Liu committed
33
34
35
36
37
38
39
40
41
42
int profile_gemm_impl(int do_verification,
                      int init_method,
                      bool do_log,
                      bool time_kernel,
                      int M,
                      int N,
                      int K,
                      int StrideA,
                      int StrideB,
                      int StrideC)
43
{
Chao Liu's avatar
Chao Liu committed
44
45
    bool pass = true;

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Chao Liu's avatar
Chao Liu committed
62
    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
63
64
65
66
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
67
    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
68
69
70

    switch(init_method)
    {
Chao Liu's avatar
Chao Liu committed
71
    case 0: break;
72
    case 1:
Chao Liu's avatar
Chao Liu committed
73
74
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
75
76
        break;
    default:
Chao Liu's avatar
Chao Liu committed
77
78
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
79
    }
Chao Liu's avatar
Chao Liu committed
80
81
82
83
84
85
86
87
88

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

89
90
91
92
93
94
95
96
    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

97
98
99
100
101
102
103
104
105
    using DeviceOp = ck::tensor_operation::device::DeviceGemm<ALayout,
                                                              BLayout,
                                                              CLayout,
                                                              ADataType,
                                                              BDataType,
                                                              CDataType,
                                                              AElementOp,
                                                              BElementOp,
                                                              CElementOp>;
106

107
108
109
110
111
    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
Chao Liu's avatar
Chao Liu committed
112

Chao Liu's avatar
Chao Liu committed
113
114
    // Run reference GEMM
    if(do_verification)
115
    {
Chao Liu's avatar
Chao Liu committed
116
117
118
119
120
121
122
        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                BElementOp,
                                                                                CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
123

Chao Liu's avatar
Chao Liu committed
124
125
        auto ref_op      = ReferenceGemmInstance{};
        auto ref_invoker = ref_op.MakeInvoker();
Jianfeng Yan's avatar
Jianfeng Yan committed
126

Chao Liu's avatar
Chao Liu committed
127
128
        auto ref_argument = ref_op.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
129

Chao Liu's avatar
Chao Liu committed
130
        ref_invoker.Run(ref_argument);
131
132
    }

Chao Liu's avatar
Chao Liu committed
133
    std::string best_op_name;
134
135
136
137
138
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device GEMM instances
Chao Liu's avatar
Chao Liu committed
139
    for(auto& op_ptr : op_ptrs)
140
141
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
142
143
144
145
146
147
148
149
150
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
151
152
153
                                        a_element_op,
                                        b_element_op,
                                        c_element_op);
Chao Liu's avatar
Chao Liu committed
154
155
156
157

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
158
        {
159
            // re-init C to zero before profiling next kernel
Chao Liu's avatar
Chao Liu committed
160
            c_device_buf.SetZero();
161

Chao Liu's avatar
Chao Liu committed
162
            std::string op_name = op_ptr->GetTypeString();
Chao Liu's avatar
Chao Liu committed
163

JD's avatar
JD committed
164
165
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
166
167

            std::size_t flop = std::size_t(2) * M * N * K;
Chao Liu's avatar
Chao Liu committed
168

169
            std::size_t num_btype =
Chao Liu's avatar
Chao Liu committed
170
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
171
172
173
174
175

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

Chao Liu's avatar
Chao Liu committed
176
            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
177
                      << gb_per_sec << " GB/s, " << op_name << std::endl;
178
179
180

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
181
                best_op_name    = op_name;
182
183
184
185
186
187
188
189
190
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
191
192
                pass =
                    pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
193
194
195
196
197

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
Chao Liu's avatar
Chao Liu committed
198
199
                    LogRangeAsType<float>(std::cout << "c_host  : ", c_m_n_host_result.mData, ",")
                        << std::endl;
200
201
202
203
204
205
206
                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
207
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
208
209
210
        }
    }

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
    if constexpr(is_same<CDataType, float>::value)
    {
        std::cout << "Best Perf for datatype = f32";
    }
    else if constexpr(is_same<CDataType, half_t>::value)
    {
        std::cout << "Best Perf for datatype = f16";
    }
    else if constexpr(is_same<CDataType, bhalf_t>::value)
    {
        std::cout << "Best Perf for datatype = bf16";
    }
    else if constexpr(is_same<CDataType, int8_t>::value)
    {
        std::cout << "Best Perf for datatype = int8";
    }

    if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " ALayout =  RowMajor";
    }
    else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " ALayout =  ColumnMajor";
    }

    if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
    {
        std::cout << " BLayout =  RowMajor";
    }
    else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
    {
        std::cout << " BLayout =  ColumnMajor";
    }

    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
              << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time
              << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
Chao Liu's avatar
Chao Liu committed
249
250
251
              << best_op_name << std::endl;

    return pass ? 0 : 1;
252
253
254
255
}

} // namespace profiler
} // namespace ck