// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
                                              |------------|
                                                   Gemm0
                                              |---------------------|
                                                       Gemm1
The M x N result of Gemm0 is used as the A operand of Gemm1.
*/

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Element types: all inputs and the output are fp16; accumulation (AccDataType)
// and the C-shuffle data type are fp32.
using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;

// Operand layouts. Shapes (as built in main): A is M x K, B0 is K x N,
// B1 is N x O, C is M x O.
using ALayout  = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout  = Row;

// All element-wise operations are the identity.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Fused Gemm+Gemm device kernel used by this example.
//
// The inline labels name the tile/XDL parameters. Arguments without a label
// follow the template-parameter order declared in
// device_gemm_gemm_xdl_cshuffle.hpp (not shown here).
// NOTE(review): by CK convention the leading `1` and `256` are likely
// NumGemmKPrefetchStage and BlockSize, and the six scalars after each
// *BlockTransfer cluster shape are arrange order, src access order, src vector
// dim, src/dst scalars-per-vector and an LDS-padding flag. Confirm against that header.
// NOTE(review): B1DataType is not passed here; the device op presumably uses
// B0DataType for B1 as well (both are F16 in this file). Confirm before changing
// either type.
// GemmDefault means no padding specialization. main() calls IsSupportedArgument
// before launching, so unsupported problem sizes are reported rather than run.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
    ALayout,
    B0Layout,
    B1Layout,
    CLayout,
    ADataType,
    B0DataType,
    CDataType,
    AccDataType,
    CShuffleDataType,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
    1,
    256,
    128,         // MPerBlock
    128,         // NPerBlock
    32,          // KPerBlock
    128,         // Gemm1NPerBlock
    32,          // Gemm1KPerBlock
    8,           // AK1
    8,           // BK1
    2,           // B1K1
    32,          // MPerXDL
    32,          // NPerXDL
    1,           // MXdlPerWave
    4,           // NXdlPerWave
    4,           // Gemm1NXdlPerWave
    S<4, 64, 1>, // ABlockTransfer
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    8,
    8,
    true,
    S<4, 64, 1>, // BBlockTransfer
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    8,
    8,
    true,
    S<8, 32, 1>, // B1BlockTransfer
    S<0, 2, 1>,
    S<0, 2, 1>,
    1,
    4,
    2,
    false,
    1,              // CShuffleMXdlPerWavePerShuffle
    2,              // CShuffleNXdlPerWavePerShuffle
    S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

// Host reference for Gemm0: A (fp16) x B0 (fp16). The output element type is
// AccDataType (fp32), so the M x N intermediate is kept at accumulation precision.
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         B0DataType,
                                                                         AccDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
                                                                         CElementOp>;

// Host reference for Gemm1: takes the fp32 Gemm0 result as its A operand and
// B1 (fp16), and produces C (fp16).
// NOTE(review): if the device kernel rounds the Gemm0 result to a narrower type
// before Gemm1, results may differ slightly from this reference. Confirm against
// the device op when tightening tolerances.
using ReferenceGemm1Instance = ck::tensor_operation::host::
    ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
Anthony Chang's avatar
Anthony Chang committed
125
    int init_method      = 1;
126
127
128
    bool time_kernel     = false;

    // GEMM shape
Anthony Chang's avatar
Anthony Chang committed
129
130
131
132
133
134
135
136
137
138
139
    // ck::index_t M = 1024;
    // ck::index_t N = 1024;
    // ck::index_t K = 64;
    // ck::index_t O = 64;

    // ck::index_t StrideA = 1024;
    // ck::index_t StrideB0 = 1024;
    // ck::index_t StrideB1 = 1024;
    // ck::index_t StrideC = 1024;

    ck::index_t M = 256;
Anthony Chang's avatar
Anthony Chang committed
140
    ck::index_t N = 128;
Anthony Chang's avatar
Anthony Chang committed
141
    ck::index_t K = 32;
Anthony Chang's avatar
Anthony Chang committed
142
143
144
145
146
    ck::index_t O = 128;
    ck::index_t StrideA = 32;
    ck::index_t StrideB0 = 32;
    ck::index_t StrideB1 = 128;
    ck::index_t StrideC = 128;
147
148
149
150
151
152
153
154
155
156
157

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
Anthony Chang's avatar
Anthony Chang committed
158
    else if(argc == 12)
159
160
161
162
163
164
165
166
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
Anthony Chang's avatar
Anthony Chang committed
167
        O = std::stoi(argv[7]);
168

Anthony Chang's avatar
Anthony Chang committed
169
170
171
172
        StrideA = std::stoi(argv[8]);
        StrideB0 = std::stoi(argv[9]);
        StrideB1 = std::stoi(argv[10]);
        StrideC = std::stoi(argv[11]);
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
        exit(0);
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

Anthony Chang's avatar
Anthony Chang committed
197
    // C_m_o = A_m_k * B0_k_n * B1_n_o
198
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Anthony Chang's avatar
Anthony Chang committed
199
200
201
202
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
    Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(N, O, StrideB1, B1Layout{}));
    Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
    Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
203
204

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
Anthony Chang's avatar
Anthony Chang committed
205
206
207
    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
    std::cout << "b1_n_o: " << b1_n_o.mDesc << std::endl;
    std::cout << "c_m_o: " << c_m_o_host_result.mDesc << std::endl;
208
209
210
211
212
213

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
Anthony Chang's avatar
Anthony Chang committed
214
215
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
216
217
        break;
    case 2:
Anthony Chang's avatar
Anthony Chang committed
218
219
220
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
221
222
        break;
    default:
Anthony Chang's avatar
Anthony Chang committed
223
224
225
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        // b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
Anthony Chang's avatar
Anthony Chang committed
226
        b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
Anthony Chang's avatar
Anthony Chang committed
227
        // b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
228
229
230
    }

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
Anthony Chang's avatar
Anthony Chang committed
231
232
233
    DeviceMem b0_k_n_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpace());
    DeviceMem b1_n_o_device_buf(sizeof(B1DataType) * b1_n_o.mDesc.GetElementSpace());
    DeviceMem c_m_o_device_buf(sizeof(CDataType) * c_m_o_device_result.mDesc.GetElementSpace());
234
235

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
Anthony Chang's avatar
Anthony Chang committed
236
    b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
Anthony Chang's avatar
Anthony Chang committed
237
    b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
238
239
240
241
242
243
244
245
246

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
Anthony Chang's avatar
Anthony Chang committed
247
                                      static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
Anthony Chang's avatar
Anthony Chang committed
248
                                      static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()),
Anthony Chang's avatar
Anthony Chang committed
249
                                      static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
250
251
252
                                      M,
                                      N,
                                      K,
Anthony Chang's avatar
Anthony Chang committed
253
                                      O,
254
                                      StrideA,
Anthony Chang's avatar
Anthony Chang committed
255
                                      StrideB0,
Anthony Chang's avatar
Anthony Chang committed
256
                                      StrideB1,
257
258
259
260
261
262
263
264
265
266
267
268
269
270
                                      StrideC,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

Anthony Chang's avatar
Anthony Chang committed
271
    std::size_t flop      = (size_t)M * N * K * 2 + (size_t)M * N * O * 2;
Anthony Chang's avatar
Anthony Chang committed
272
273
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                            sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O;
274
275
276
277
278
279
280
281

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

Anthony Chang's avatar
Anthony Chang committed
282
    c_m_o_device_buf.FromDevice(c_m_o_device_result.mData.data());
283
284
285

    if(do_verification)
    {
Anthony Chang's avatar
Anthony Chang committed
286
287
288
289
290
291
292
293
294
        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> a1_m_n(f_host_tensor_descriptor(M, N, N, Row{}));

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);
295

Anthony Chang's avatar
Anthony Chang committed
296
297
298
299
        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
300

Anthony Chang's avatar
Anthony Chang committed
301
        ref_gemm1_invoker.Run(ref_gemm1_argument);
302

Anthony Chang's avatar
Anthony Chang committed
303
304
305
306
307
308
309
310
311
        // LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
        // LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
        // LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
        // LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;

        std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0) << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
                  << ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1) << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1)
                  << std::endl;

Anthony Chang's avatar
Anthony Chang committed
312
        return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
313
314
315
316
    }

    return 0;
}