gemm_split_k.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_gemm.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"

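// Matrix layouts for A * B = C, named <A>_<B>_<C>:
// MK = row-major A[m, k], KM = column-major A (stored as A[k, m]);
// KN = row-major B[k, n], NK = column-major B (stored as B[n, k]);
// MN = row-major C[m, n] (the only output layout exercised here).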
enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

// Element-wise comparison of a computed tensor against the reference,
// using an absolute tolerance.
template <typename T>
static std::enable_if_t<std::is_convertible_v<T, double>, bool> check_out(const Tensor<T>& ref,
                                                                          const Tensor<T>& out)
{
    if(out.size() != ref.size())
    {
        return false;
    }
    constexpr double max_diff = 1e-6;

    auto o = out.begin();
    for(auto r = ref.begin(); r != ref.end(); ++r, ++o)
    {
        const double diff = std::abs(double(*r) - double(*o));

        if(max_diff < diff)
        {
            return false;
        }
    }

    return true;
}

// Problem description for a single split-K GEMM test case.
struct gemmArgs
{
    GemmMatrixLayout layout;
    int M;
    int N;
    int K;
    int StrideA;
    int StrideB;
    int StrideC;
    int KBatch; // split-K factor: number of partitions of the K dimension
};

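// Runs every available DeviceGemmSplitK instance for the requested layout and
// checks each device result against a host reference GEMM. Returns 0 on
// success and a non-zero error code on failure.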
int test_gemm(const gemmArgs& args)
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    bool a_row_major, b_row_major, c_row_major;

    switch(args.layout)
    {
    case GemmMatrixLayout::MK_KN_MN:
        a_row_major = true;
        b_row_major = true;
        c_row_major = true;
        break;
    case GemmMatrixLayout::MK_NK_MN:
        a_row_major = true;
        b_row_major = false;
        c_row_major = true;
        break;
    case GemmMatrixLayout::KM_KN_MN:
        a_row_major = false;
        b_row_major = true;
        c_row_major = true;
        break;
    case GemmMatrixLayout::KM_NK_MN:
        a_row_major = false;
        b_row_major = false;
        c_row_major = true;
        break;
    default: printf("unsupported layout\n"); return 1;
    }

    using namespace ck::literals;

    // Builds a HostTensorDescriptor whose strides match the requested
    // row-/column-major layout.
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
            if(row_major)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
    Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
    Tensor<float> c_m_n_host_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
    Tensor<float> c_m_n_device_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));

    // init data
    std::size_t num_thread = 1;
    a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);

    host_gemm_mk_kn_mn(a_m_k,
                       b_k_n,
                       c_m_n_host_result,
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{});

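    // Allocate device buffers and copy the input (and zero-initialized output)
    // tensors to the GPU.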
    DeviceMem a_device_buf(a_m_k.GetMemorySize());
    DeviceMem b_device_buf(b_k_n.GetMemorySize());
    DeviceMem c_device_buf(c_m_n_device_result.GetMemorySize());

    a_device_buf.ToDevice(a_m_k.data());
    b_device_buf.ToDevice(b_k_n.data());
    c_device_buf.ToDevice(c_m_n_device_result.data());

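    // For the given layout triple, query the instance factory for all
    // DeviceGemmSplitK instances, run every instance that supports the
    // argument, and verify each result against the host reference.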
    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
        bool success = false;

        using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
                                                                        decltype(b_layout),
                                                                        decltype(c_layout),
                                                                        float,
                                                                        float,
                                                                        float,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

        const auto gemm_ptrs =
            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
                DeviceOp>::GetInstances();

        for(auto& gemm_ptr : gemm_ptrs)
        {
            auto argument_ptr =
                gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
                                              static_cast<float*>(b_device_buf.GetDeviceBuffer()),
                                              static_cast<float*>(c_device_buf.GetDeviceBuffer()),
                                              args.M,
                                              args.N,
                                              args.K,
                                              args.StrideA,
                                              args.StrideB,
                                              args.StrideC,
                                              ck::tensor_operation::element_wise::PassThrough{},
                                              ck::tensor_operation::element_wise::PassThrough{},
                                              ck::tensor_operation::element_wise::PassThrough{},
                                              args.KBatch);

            auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

            if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                invoker_ptr->Run(argument_ptr.get());

                c_device_buf.FromDevice(c_m_n_device_result.data());

                if(!check_out(c_m_n_host_result, c_m_n_device_result))
                {
                    success = false;
                    break;
                }
                success = true;
            }
        }

        return success;
    };

    bool success = false;

    if(args.layout == GemmMatrixLayout::MK_KN_MN)
    {
        success = test(Row{}, Row{}, Row{});
    }
    else if(args.layout == GemmMatrixLayout::MK_NK_MN)
    {
        success = test(Row{}, Col{}, Row{});
    }
    else if(args.layout == GemmMatrixLayout::KM_KN_MN)
    {
        success = test(Col{}, Row{}, Row{});
    }
    else
    {
        success = test(Col{}, Col{}, Row{});
    }

    auto error_code = 0;
    if(success)
    {
        std::cout << "test split k: Pass" << std::endl;
    }
    else
    {
        std::cout << "test split k: Fail" << std::endl;
        error_code = -1; // test needs to report failure
    }
    return error_code;
}

int main(int argc, char* argv[])
{
    std::vector<gemmArgs> test_cases;
    if(argc == 1)
    {
        test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
        // TODO: populate with more meaningful default test cases and remove
        // this early return
        return 0;
    }
    else if(argc == 9)
    {
        const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));

        const int M = std::stoi(argv[2]);
        const int N = std::stoi(argv[3]);
        const int K = std::stoi(argv[4]);

        const int StrideA = std::stoi(argv[5]);
        const int StrideB = std::stoi(argv[6]);
        const int StrideC = std::stoi(argv[7]);
        const int KBatch  = std::stoi(argv[8]);
        test_cases        = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
    }
    else
    {
        printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg2 to 8: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
        return -1;
    }
    // Run all test cases; report failure as soon as one of them fails.
    for(const auto& test_case : test_cases)
    {
        const auto res = test_gemm(test_case);
        if(res != 0)
        {
            return res;
        }
    }
    return 0;
}