// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
30
bool profile_batched_gemm_impl(int do_verification,
zjing14's avatar
zjing14 committed
31
32
                               int init_method,
                               bool do_log,
JD's avatar
JD committed
33
                               bool time_kernel,
zjing14's avatar
zjing14 committed
34
35
36
                               int M,
                               int N,
                               int K,
37
38
39
                               int BatchStrideA,
                               int BatchStrideB,
                               int BatchStrideC,
zjing14's avatar
zjing14 committed
40
41
42
                               int StrideA,
                               int StrideB,
                               int StrideC,
43
                               int BatchCount)
zjing14's avatar
zjing14 committed
44
{
45
46
    bool pass = true;

zjing14's avatar
zjing14 committed
47
48
49
50
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
51
                                       std::size_t batch_stride,
zjing14's avatar
zjing14 committed
52
53
54
55
                                       auto layout) {
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
56
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
zjing14's avatar
zjing14 committed
57
58
59
60
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
61
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
zjing14's avatar
zjing14 committed
62
63
64
        }
    };

65
66
67
68
    Tensor<ADataType> a_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
zjing14's avatar
zjing14 committed
69
    Tensor<CDataType> c_g_m_n_host_result(
70
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
71
    Tensor<CDataType> c_g_m_n_device_result(
72
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
73
74
75
76
77
78
79
80
81

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
82
83
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
zjing14's avatar
zjing14 committed
84
85
        break;
    default:
Chao Liu's avatar
Chao Liu committed
86
87
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
zjing14's avatar
zjing14 committed
88
89
90
91
92
93
94
95
96
97
98
99
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
100
101
102
103
104
105
106
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
107

Chao Liu's avatar
Chao Liu committed
108
109
        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
zjing14's avatar
zjing14 committed
110

Chao Liu's avatar
Chao Liu committed
111
112
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
zjing14's avatar
zjing14 committed
113

Chao Liu's avatar
Chao Liu committed
114
        ref_invoker.Run(ref_argument);
zjing14's avatar
zjing14 committed
115
116
    }

117
118
119
    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
zjing14's avatar
zjing14 committed
120
121
122
123
124

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
zjing14's avatar
zjing14 committed
140

Chao Liu's avatar
Chao Liu committed
141
    std::string best_op_name;
zjing14's avatar
zjing14 committed
142
143
144
145
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
146
147
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
zjing14's avatar
zjing14 committed
148
149
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
150
151
152
153
154
155
156
157
158
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
159
160
161
                                        BatchStrideA,
                                        BatchStrideB,
                                        BatchStrideC,
Chao Liu's avatar
Chao Liu committed
162
                                        BatchCount,
Chao Liu's avatar
Chao Liu committed
163
164
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
Chao Liu's avatar
Chao Liu committed
165
                                        ck::tensor_operation::element_wise::PassThrough{});
Chao Liu's avatar
Chao Liu committed
166
167
168
169

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
zjing14's avatar
zjing14 committed
170
        {
Chao Liu's avatar
Chao Liu committed
171
172
173
174
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();
zjing14's avatar
zjing14 committed
175

JD's avatar
JD committed
176
177
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
zjing14's avatar
zjing14 committed
178
179
180

            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

JD's avatar
JD committed
181
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
zjing14's avatar
zjing14 committed
182
183
184
185
186
187
188
189
                                     sizeof(CDataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
Chao Liu's avatar
Chao Liu committed
190
                      << " GB/s, " << op_name << std::endl;
zjing14's avatar
zjing14 committed
191
192
193

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
194
                best_op_name    = op_name;
zjing14's avatar
zjing14 committed
195
196
197
198
199
200
201
202
203
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
204
205
                pass = pass &
                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
zjing14's avatar
zjing14 committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
221
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
zjing14's avatar
zjing14 committed
222
223
224
225
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
226
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
227
228

    return pass;
zjing14's avatar
zjing14 committed
229
230
231
232
}

} // namespace profiler
} // namespace ck