// profile_batched_gemm_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
31
bool profile_batched_gemm_impl(int do_verification,
zjing14's avatar
zjing14 committed
32
33
                               int init_method,
                               bool do_log,
JD's avatar
JD committed
34
                               bool time_kernel,
zjing14's avatar
zjing14 committed
35
36
37
                               int M,
                               int N,
                               int K,
38
39
40
                               int BatchStrideA,
                               int BatchStrideB,
                               int BatchStrideC,
zjing14's avatar
zjing14 committed
41
42
43
                               int StrideA,
                               int StrideB,
                               int StrideC,
44
                               int BatchCount)
zjing14's avatar
zjing14 committed
45
{
46
47
    bool pass = true;

zjing14's avatar
zjing14 committed
48
49
50
51
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
52
                                       std::size_t batch_stride,
zjing14's avatar
zjing14 committed
53
                                       auto layout) {
54
55
        using namespace ck::literals;

zjing14's avatar
zjing14 committed
56
57
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
58
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
zjing14's avatar
zjing14 committed
59
60
61
        }
        else
        {
62
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
zjing14's avatar
zjing14 committed
63
64
65
        }
    };

66
67
68
69
    Tensor<ADataType> a_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
zjing14's avatar
zjing14 committed
70
    Tensor<CDataType> c_g_m_n_host_result(
71
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
72
    Tensor<CDataType> c_g_m_n_device_result(
73
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
zjing14's avatar
zjing14 committed
74
75
76
77
78
79
80
81
82

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
83
84
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
zjing14's avatar
zjing14 committed
85
86
        break;
    default:
Chao Liu's avatar
Chao Liu committed
87
88
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
zjing14's avatar
zjing14 committed
89
90
91
92
93
94
95
96
97
98
99
100
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
101
102
103
104
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
Anthony Chang's avatar
Anthony Chang committed
105
                                                             float,
Chao Liu's avatar
Chao Liu committed
106
107
108
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
109

Chao Liu's avatar
Chao Liu committed
110
111
        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
zjing14's avatar
zjing14 committed
112

Chao Liu's avatar
Chao Liu committed
113
114
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
zjing14's avatar
zjing14 committed
115

Chao Liu's avatar
Chao Liu committed
116
        ref_invoker.Run(ref_argument);
zjing14's avatar
zjing14 committed
117
118
    }

119
120
121
    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
zjing14's avatar
zjing14 committed
122
123
124
125
126

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
zjing14's avatar
zjing14 committed
142

Chao Liu's avatar
Chao Liu committed
143
    std::string best_op_name;
zjing14's avatar
zjing14 committed
144
145
146
147
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
148
149
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
zjing14's avatar
zjing14 committed
150
151
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
152
153
154
155
156
157
158
159
160
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
161
162
163
                                        BatchStrideA,
                                        BatchStrideB,
                                        BatchStrideC,
Chao Liu's avatar
Chao Liu committed
164
                                        BatchCount,
Chao Liu's avatar
Chao Liu committed
165
166
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
Chao Liu's avatar
Chao Liu committed
167
                                        ck::tensor_operation::element_wise::PassThrough{});
Chao Liu's avatar
Chao Liu committed
168
169
170
171

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
zjing14's avatar
zjing14 committed
172
        {
Chao Liu's avatar
Chao Liu committed
173
174
175
176
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();
zjing14's avatar
zjing14 committed
177

JD's avatar
JD committed
178
179
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
zjing14's avatar
zjing14 committed
180
181
182

            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

JD's avatar
JD committed
183
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
zjing14's avatar
zjing14 committed
184
185
186
187
188
189
190
191
                                     sizeof(CDataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
Chao Liu's avatar
Chao Liu committed
192
                      << " GB/s, " << op_name << std::endl;
zjing14's avatar
zjing14 committed
193
194
195

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
196
                best_op_name    = op_name;
zjing14's avatar
zjing14 committed
197
198
199
200
201
202
203
204
205
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

206
                pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
zjing14's avatar
zjing14 committed
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
222
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
zjing14's avatar
zjing14 committed
223
224
225
226
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
227
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
228
229

    return pass;
zjing14's avatar
zjing14 committed
230
231
232
233
}

} // namespace profiler
} // namespace ck