#include <iostream>
#include <initializer_list>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <thread>
#include <vector>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_xdl_splitk.hpp"

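// Layout naming is <A>_<B>_<C>: MK, KN, and MN denote row-major A[m, k],
// B[k, n], and C[m, n]; KM and NK denote the transposed (column-major)
// storage of A and B.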
enum GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

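// Device GEMM instance pointer with PassThrough (identity) element-wise
// operations applied to A, B, and C.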
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

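// One {A, B, C} flag triple per GemmMatrixLayout value; a true flag makes
// f_host_tensor_descriptor (in main) build column-major strides for that
// matrix.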
static std::vector<std::vector<bool>>& GetLayoutType()
{
    static std::vector<std::vector<bool>> LayOut = {{false, false, false},  // MK_KN_MN
                                                    {false, true, false},   // MK_NK_MN
                                                    {true, false, false},   // KM_KN_MN
                                                    {true, true, false}};   // KM_NK_MN
    return LayOut;
}

#if 0
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
    ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
        float,
        float,
        float,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}

static void add_device_gemm_instance_mk_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
    ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
        float,
        float,
        float,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::ColumnMajor,
        ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
    ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
        float,
        float,
        float,
        ck::tensor_layout::gemm::ColumnMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
    ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
        float,
        float,
        float,
        ck::tensor_layout::gemm::ColumnMajor,
        ck::tensor_layout::gemm::ColumnMajor,
        ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}

static auto& GetAddDeviceGemmInstance()
{
    static std::vector<void (*)(std::vector<DeviceGemmNoOpPtr>&)> AddDeviceGemmInstance = {
        add_device_gemm_instance_mk_kn_mn,
        add_device_gemm_instance_mk_nk_mn,
        add_device_gemm_instance_km_kn_mn,
        add_device_gemm_instance_km_nk_mn};
    return AddDeviceGemmInstance;
}
#else
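// Instance factories declared here and defined in separately compiled
// translation units.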
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif

static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout)
{
#if 0
    GetAddDeviceGemmInstance()[layout](gemm_ptrs);
#else
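    // Only the KM_KN_MN (layout 2) instances are currently wired up in this
    // branch; for any other layout, gemm_ptrs stays empty and the test below
    // reports Fail.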
    if(layout == 2)
    {
        add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
    }
#endif
}

template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    // Element-wise comparison against an absolute tolerance.
    const double max_diff = 1e-6;

    if(ref.mData.size() != result.mData.size())
    {
        return false;
    }

    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        const double diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            return false;
        }
    }

    return true;
}

int main(int argc, char* argv[])
{
    if(argc != 9)
    {
        printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg2 to 8: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
        return 1;
    }
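    // Illustrative invocation (argument values are arbitrary; the binary name
    // depends on the build target):
    //   ./gemm_xdl_splitk 2 1024 1024 1024 1024 1024 1024 4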

    const int layout = std::stoi(argv[1]);

    const int M = std::stoi(argv[2]);
    const int N = std::stoi(argv[3]);
    const int K = std::stoi(argv[4]);

    const int StrideA = std::stoi(argv[5]);
    const int StrideB = std::stoi(argv[6]);
    const int StrideC = std::stoi(argv[7]);
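    // KBatch is the split-K factor: the K dimension is partitioned into
    // KBatch slices whose partial products are accumulated into C.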
    const int KBatch  = std::stoi(argv[8]);

    if(layout > 3 || layout < 0)
    {
        printf("arg1 must be 0, 1, 2, or 3\n");
        return 1;
    }
    const auto& LayOut = GetLayoutType();

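    // Build a host tensor descriptor: row-major strides {stride, 1} by
    // default, column-major strides {1, stride} when isRevert is set.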
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, bool isRevert) {
            if(isRevert)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
        };

    Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0]));
    Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1]));
    Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
    Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));

    // init data
    std::size_t num_thread = std::thread::hardware_concurrency();
    a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);

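    // Compute the reference result on the host; the descriptors above already
    // encode the requested layouts, so the single mk_kn_mn reference covers
    // all four layout cases.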
    host_gemm_mk_kn_mn(a_m_k,
                       b_k_n,
                       c_m_n_host_result,
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{});

    DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // add device GEMM instances
    std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
    add_device_gemm_instance(gemm_ptrs, layout);

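    // Run every registered instance: skip those that do not support this
    // problem, and stop at the first verification mismatch.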
    bool success = false;
    for(auto& gemm_ptr : gemm_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<float*>(b_device_buf.GetDeviceBuffer()),
                                          static_cast<float*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          KBatch);

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), 0);

            c_device_buf.FromDevice(c_m_n_device_result.mData.data());
            if(!check_out(c_m_n_host_result, c_m_n_device_result))
            {
                success = false;
                break;
            }
            success = true;
        }
    }

    if(success)
    {
        std::cout << "test split k: Pass" << std::endl;
    }
    else
    {
        std::cout << "test split k: Fail" << std::endl;
    }
    return 0;
}