// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace profiler {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout>
30
bool profile_batched_gemm_impl(int do_verification,
zjing14's avatar
zjing14 committed
31
32
                               int init_method,
                               bool do_log,
JD's avatar
JD committed
33
                               bool time_kernel,
zjing14's avatar
zjing14 committed
34
35
36
37
38
39
                               int M,
                               int N,
                               int K,
                               int StrideA,
                               int StrideB,
                               int StrideC,
40
                               int BatchCount)
zjing14's avatar
zjing14 committed
41
{
42
43
    bool pass = true;

zjing14's avatar
zjing14 committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       auto layout) {
        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({row * stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({col * stride, 1, stride}));
        }
    };

    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_g_m_n_host_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_g_m_n_device_result(
        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
Chao Liu's avatar
Chao Liu committed
76
77
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
zjing14's avatar
zjing14 committed
78
79
        break;
    default:
Chao Liu's avatar
Chao Liu committed
80
81
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
zjing14's avatar
zjing14 committed
82
83
84
85
86
87
88
89
90
91
92
93
    }

    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
    using CElementOp = ck::tensor_operation::element_wise::PassThrough;

    const auto a_element_op = AElementOp{};
    const auto b_element_op = BElementOp{};
    const auto c_element_op = CElementOp{};

    if(do_verification)
    {
Chao Liu's avatar
Chao Liu committed
94
95
96
97
98
99
100
        using ReferenceBatchedGemmInstance =
            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
101

Chao Liu's avatar
Chao Liu committed
102
103
        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
zjing14's avatar
zjing14 committed
104

Chao Liu's avatar
Chao Liu committed
105
106
        auto ref_argument = ref_batched_gemm.MakeArgument(
            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
zjing14's avatar
zjing14 committed
107

Chao Liu's avatar
Chao Liu committed
108
        ref_invoker.Run(ref_argument);
zjing14's avatar
zjing14 committed
109
110
111
112
113
114
115
116
117
118
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b_device_buf.ToDevice(b_g_k_n.mData.data());
    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
119
120
121
122
123
124
125
126
    // add device op instances
    const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance::
        get_device_batched_gemm_instances<ADataType,
                                          BDataType,
                                          CDataType,
                                          ALayout,
                                          BLayout,
                                          CLayout>();
zjing14's avatar
zjing14 committed
127

Chao Liu's avatar
Chao Liu committed
128
    if(op_ptrs.size() <= 0)
zjing14's avatar
zjing14 committed
129
130
131
132
    {
        throw std::runtime_error("wrong! no device GEMM instance found");
    }

Chao Liu's avatar
Chao Liu committed
133
    std::string best_op_name;
zjing14's avatar
zjing14 committed
134
135
136
137
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

Chao Liu's avatar
Chao Liu committed
138
139
    // profile device op instances
    for(auto& op_ptr : op_ptrs)
zjing14's avatar
zjing14 committed
140
141
    {
        auto argument_ptr =
Chao Liu's avatar
Chao Liu committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        StrideC,
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        ck::tensor_operation::element_wise::PassThrough{},
                                        BatchCount);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
zjing14's avatar
zjing14 committed
159
        {
Chao Liu's avatar
Chao Liu committed
160
161
162
163
            // re-init C to zero before profiling next kernel
            c_device_buf.SetZero();

            std::string op_name = op_ptr->GetTypeString();
zjing14's avatar
zjing14 committed
164

JD's avatar
JD committed
165
166
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
zjing14's avatar
zjing14 committed
167
168
169

            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

JD's avatar
JD committed
170
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
zjing14's avatar
zjing14 committed
171
172
173
174
175
176
177
178
                                     sizeof(CDataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
Chao Liu's avatar
Chao Liu committed
179
                      << " GB/s, " << op_name << std::endl;
zjing14's avatar
zjing14 committed
180
181
182

            if(tflops > best_tflops)
            {
Chao Liu's avatar
Chao Liu committed
183
                best_op_name    = op_name;
zjing14's avatar
zjing14 committed
184
185
186
187
188
189
190
191
192
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());

Chao Liu's avatar
Chao Liu committed
193
194
                pass = pass &
                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
zjing14's avatar
zjing14 committed
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
Chao Liu's avatar
Chao Liu committed
210
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
zjing14's avatar
zjing14 committed
211
212
213
214
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
Chao Liu's avatar
Chao Liu committed
215
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
216
217

    return pass;
zjing14's avatar
zjing14 committed
218
219
220
221
}

} // namespace profiler
} // namespace ck