// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
30
bool profile_batched_gemm_impl(int do_verification,
zjing14's avatar
zjing14 committed
31
32
                               int init_method,
                               bool do_log,
JD's avatar
JD committed
33
                               bool time_kernel,
zjing14's avatar
zjing14 committed
34
35
36
                               int M,
                               int N,
                               int K,
37
38
39
                               int BatchStrideA,
                               int BatchStrideB,
                               int BatchStrideC,
zjing14's avatar
zjing14 committed
40
41
42
                               int StrideA,
                               int StrideB,
                               int StrideC,
43
                               int BatchCount)
zjing14's avatar
zjing14 committed
44
{
45
46
    bool pass = true;

zjing14's avatar
zjing14 committed
47
48
49
50
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
51
                                       std::size_t batch_stride,
zjing14's avatar
zjing14 committed
52
53
54
55
                                       auto layout) {
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
56
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
zjing14's avatar
zjing14 committed
57
58
59
60
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
61
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
zjing14's avatar
zjing14 committed
62
63
64
        }
    };

65
66
67
68
    Tensor<ADataType> a_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
zjing14's avatar
zjing14 committed
69
    Tensor<CDataType> c_g_m_n_host_result(
70
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
71
    Tensor<CDataType> c_g_m_n_device_result(
72
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
73
74
75
76
77
78
79
80
81

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
82
83
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
zjing14's avatar
zjing14 committed
84
85
        break;
    default:
Chao Liu's avatar
Chao Liu committed
86
87
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
zjing14's avatar
zjing14 committed
88
89
90
91
92
93
94
95
96
97
98
99
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
100
101
102
103
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
Anthony Chang's avatar
Anthony Chang committed
104
                                                             float,
Chao Liu's avatar
Chao Liu committed
105
106
107
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
108

Chao Liu's avatar
Chao Liu committed
109
110
        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
zjing14's avatar
zjing14 committed
111

Chao Liu's avatar
Chao Liu committed
112
113
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
zjing14's avatar
zjing14 committed
114

Chao Liu's avatar
Chao Liu committed
115
        ref_invoker.Run(ref_argument);
zjing14's avatar
zjing14 committed
116
117
    }

118
119
120
    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
zjing14's avatar
zjing14 committed
121
122
123
124
125

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
zjing14's avatar
zjing14 committed
141

Chao Liu's avatar
Chao Liu committed
142
    std::string best_op_name;
zjing14's avatar
zjing14 committed
143
144
145
146
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
147
148
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
zjing14's avatar
zjing14 committed
149
150
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
151
152
153
154
155
156
157
158
159
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
160
161
162
                                        BatchStrideA,
                                        BatchStrideB,
                                        BatchStrideC,
Chao Liu's avatar
Chao Liu committed
163
                                        BatchCount,
Chao Liu's avatar
Chao Liu committed
164
165
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
Chao Liu's avatar
Chao Liu committed
166
                                        ck::tensor_operation::element_wise::PassThrough{});
Chao Liu's avatar
Chao Liu committed
167
168
169
170

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
zjing14's avatar
zjing14 committed
171
        {
Chao Liu's avatar
Chao Liu committed
172
173
174
175
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();
zjing14's avatar
zjing14 committed
176

JD's avatar
JD committed
177
178
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
zjing14's avatar
zjing14 committed
179
180
181

            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

JD's avatar
JD committed
182
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
zjing14's avatar
zjing14 committed
183
184
185
186
187
188
189
190
                                     sizeof(CDataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
Chao Liu's avatar
Chao Liu committed
191
                      << " GB/s, " << op_name << std::endl;
zjing14's avatar
zjing14 committed
192
193
194

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
195
                best_op_name    = op_name;
zjing14's avatar
zjing14 committed
196
197
198
199
200
201
202
203
204
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
205
206
                pass = pass &
                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
zjing14's avatar
zjing14 committed
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
222
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
zjing14's avatar
zjing14 committed
223
224
225
226
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
227
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
228
229

    return pass;
zjing14's avatar
zjing14 committed
230
231
232
233
}

} // namespace profiler
} // namespace ck