profile_batched_gemm_impl.hpp 9.15 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

zjing14's avatar
zjing14 committed
4
#pragma once
5

Jianfeng Yan's avatar
Jianfeng Yan committed
6
#include <memory>
7

Chao Liu's avatar
Chao Liu committed
8
9
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
Chao Liu's avatar
Chao Liu committed
10
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
Chao Liu's avatar
Chao Liu committed
11
12
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

13
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
Chao Liu's avatar
Chao Liu committed
14

Chao Liu's avatar
Chao Liu committed
15
16
17
18
19
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
zjing14's avatar
zjing14 committed
20
21
22
23
24
25
26
27
28
29

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
30
bool profile_batched_gemm_impl(int do_verification,
zjing14's avatar
zjing14 committed
31
32
                               int init_method,
                               bool do_log,
JD's avatar
JD committed
33
                               bool time_kernel,
zjing14's avatar
zjing14 committed
34
35
36
37
38
39
                               int M,
                               int N,
                               int K,
                               int StrideA,
                               int StrideB,
                               int StrideC,
40
                               int BatchCount)
zjing14's avatar
zjing14 committed
41
{
42
43
    bool pass = true;

zjing14's avatar
zjing14 committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       auto layout) {
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({row * stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({col * stride, 1, stride}));
        }
    };

    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_g_m_n_host_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_g_m_n_device_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
76
77
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
zjing14's avatar
zjing14 committed
78
79
        break;
    default:
Chao Liu's avatar
Chao Liu committed
80
81
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
zjing14's avatar
zjing14 committed
82
83
84
85
86
87
88
89
90
91
92
93
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
94
95
96
97
98
99
100
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
101

Chao Liu's avatar
Chao Liu committed
102
103
        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
zjing14's avatar
zjing14 committed
104

Chao Liu's avatar
Chao Liu committed
105
106
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
zjing14's avatar
zjing14 committed
107

Chao Liu's avatar
Chao Liu committed
108
        ref_invoker.Run(ref_argument);
zjing14's avatar
zjing14 committed
109
110
111
112
113
114
115
116
117
118
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     CElementOp>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
zjing14's avatar
zjing14 committed
134

Chao Liu's avatar
Chao Liu committed
135
    std::string best_op_name;
zjing14's avatar
zjing14 committed
136
137
138
139
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
140
141
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
zjing14's avatar
zjing14 committed
142
143
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        BatchCount);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
zjing14's avatar
zjing14 committed
161
        {
Chao Liu's avatar
Chao Liu committed
162
163
164
165
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();
zjing14's avatar
zjing14 committed
166

JD's avatar
JD committed
167
168
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
zjing14's avatar
zjing14 committed
169
170
171

            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

JD's avatar
JD committed
172
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
zjing14's avatar
zjing14 committed
173
174
175
176
177
178
179
180
                                     sizeof(CDataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
Chao Liu's avatar
Chao Liu committed
181
                      << " GB/s, " << op_name << std::endl;
zjing14's avatar
zjing14 committed
182
183
184

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
185
                best_op_name    = op_name;
zjing14's avatar
zjing14 committed
186
187
188
189
190
191
192
193
194
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
195
196
                pass = pass &
                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
zjing14's avatar
zjing14 committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
212
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
zjing14's avatar
zjing14 committed
213
214
215
216
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
217
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
218
219

    return pass;
zjing14's avatar
zjing14 committed
220
221
222
223
}

} // namespace profiler
} // namespace ck