// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
                                              |------------|
                                                   Gemm0
                                              |---------------------|
                                                       Gemm1
*/

#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <string>
#include <type_traits>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

// Shorthand for a compile-time integer sequence, used in the tuning table below.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Scalar types.
using F16 = ck::half_t;
using F32 = float;

// Matrix layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Element types of the operands of C_m_o = A_m_k * B0_k_n * B1_n_o.
using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;

// Memory layouts of the operands.
using ALayout  = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout  = Row;

// Element-wise operations; all are PassThrough in this example.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
Anthony Chang's avatar
Anthony Chang committed
58
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle
59
60
61
62
//######| ALayout| BLayout| CLayout|     AData|     BData|     CData|     AccData|         CShuffle|           A|           B|           C|           GEMM| NumGemmK| Block|  MPer|  NPer|  KPer| AK1| BK1| MPer| NPer| MXdl| NXdl|  ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|  BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|    CShuffle|    CShuffle| CBlockTransferClusterLengths|  CBlockTransfer|
//######|        |        |        |      Type|      Type|      Type|        Type|         DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch|  Size| Block| Block| Block|    |    |  XDL|  XDL|  Per|  Per|   ThreadCluster|  ThreadCluster| SrcAccessOrder|   SrcVectorDim|      SrcScalar|      DstScalar| AddExtraM|   ThreadCluster|  ThreadCluster| SrcAccessOrder|  SrcVectorDim|      SrcScalar|      DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|         _MBlock_MWaveMPerXdl| ScalarPerVector|
//######|        |        |        |          |          |          |            |                 |   Operation|   Operation|   Operation|               |    Stage|      |      |      |      |    |    |     |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|               |               |      PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|               |              |      PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle|         _NBlock_NWaveNPerXdl|   _NWaveNPerXdl|
//######|        |        |        |          |          |          |            |                 |            |            |            |               |         |      |      |      |      |    |    |     |     |     |     |                |               |               |               |               |               |          |                |               |               |              |               |               |          |            |            |                             |                |
Anthony Chang's avatar
Anthony Chang committed
63
        < ALayout,B0Layout, CLayout, ADataType,B0DataType, CDataType, AccDataType, CShuffleDataType,  AElementOp,  BElementOp,  CElementOp,    GemmDefault,        1,   256,   256,   128,    32,   8,   8,   32,   32,    4,    2,     S<4, 64, 1>,     S<1, 0, 2>,     S<1, 0, 2>,              2,              8,              8,         1,     S<4, 64, 1>,     S<1, 0, 2>,     S<1, 0, 2>,             2,              8,              8,         1,           1,           1,               S<1, 32, 1, 8>,               8>;
64
65
// clang-format on

Anthony Chang's avatar
Anthony Chang committed
66
67
68
69
// Host reference for Gemm0. In main() it is run on (a_m_k, b0_k_n) and writes a1_m_n, a
// Tensor<AccDataType>, so the intermediate product is kept in AccDataType on the host.
// NOTE(review): the fourth template argument is presumably the accumulation type -- confirm
// against ReferenceGemm's parameter list.
using ReferenceGemm0Instance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

// Host reference for Gemm1. In main() it consumes Gemm0's AccDataType output together with
// b1_n_o and writes the CDataType host result.
using ReferenceGemm1Instance = ck::tensor_operation::host::
    ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
Anthony Chang's avatar
Anthony Chang committed
74
75
    // int init_method      = 1;
    int init_method      = 3;
76
77
78
    bool time_kernel     = false;

    // GEMM shape
Anthony Chang's avatar
Anthony Chang committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    // ck::index_t M = 1024;
    // ck::index_t N = 1024;
    // ck::index_t K = 64;
    // ck::index_t O = 64;

    // ck::index_t StrideA = 1024;
    // ck::index_t StrideB0 = 1024;
    // ck::index_t StrideB1 = 1024;
    // ck::index_t StrideC = 1024;

    ck::index_t M = 256;
    ck::index_t N = 256;
    ck::index_t K = 32;
    ck::index_t O = 256;
    ck::index_t StrideA = 256;
    ck::index_t StrideB0 = 256;
    ck::index_t StrideB1 = 256;
    ck::index_t StrideC = 256;
97
98
99
100
101
102
103
104
105
106
107

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
Anthony Chang's avatar
Anthony Chang committed
108
    else if(argc == 12)
109
110
111
112
113
114
115
116
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
Anthony Chang's avatar
Anthony Chang committed
117
        O = std::stoi(argv[7]);
118

Anthony Chang's avatar
Anthony Chang committed
119
120
121
122
        StrideA = std::stoi(argv[8]);
        StrideB0 = std::stoi(argv[9]);
        StrideB1 = std::stoi(argv[10]);
        StrideC = std::stoi(argv[11]);
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
        exit(0);
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };

Anthony Chang's avatar
Anthony Chang committed
147
    // C_m_o = A_m_k * B0_k_n * B1_n_o
148
    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Anthony Chang's avatar
Anthony Chang committed
149
150
151
152
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB0, B0Layout{}));
    Tensor<B1DataType> b1_n_o(f_host_tensor_descriptor(N, O, StrideB1, B1Layout{}));
    Tensor<CDataType> c_m_o_host_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
    Tensor<CDataType> c_m_o_device_result(f_host_tensor_descriptor(N, O, StrideC, CLayout{}));
153
154

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
Anthony Chang's avatar
Anthony Chang committed
155
156
157
    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
    std::cout << "b1_n_o: " << b1_n_o.mDesc << std::endl;
    std::cout << "c_m_o: " << c_m_o_host_result.mDesc << std::endl;
158
159
160
161
162
163

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
Anthony Chang's avatar
Anthony Chang committed
164
165
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
166
167
168
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
Anthony Chang's avatar
Anthony Chang committed
169
170
        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
171
172
173
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
Anthony Chang's avatar
Anthony Chang committed
174
175
        b0_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
176
177
178
    }

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
Anthony Chang's avatar
Anthony Chang committed
179
180
181
    DeviceMem b0_k_n_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpace());
    DeviceMem b1_n_o_device_buf(sizeof(B1DataType) * b1_n_o.mDesc.GetElementSpace());
    DeviceMem c_m_o_device_buf(sizeof(CDataType) * c_m_o_device_result.mDesc.GetElementSpace());
182
183

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
Anthony Chang's avatar
Anthony Chang committed
184
    b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
185
186
187
188
189
190
191
192
193

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
Anthony Chang's avatar
Anthony Chang committed
194
195
                                      static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
196
197
198
199
                                      M,
                                      N,
                                      K,
                                      StrideA,
Anthony Chang's avatar
Anthony Chang committed
200
                                      StrideB0,
201
202
203
204
205
206
207
208
209
210
211
212
213
214
                                      StrideC,
                                      a_element_op,
                                      b_element_op,
                                      c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

Anthony Chang's avatar
Anthony Chang committed
215
216
217
    std::size_t flop      = std::size_t(2) * (M * N * K + M * N * O);
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                            sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O;
218
219
220
221
222
223
224
225

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

Anthony Chang's avatar
Anthony Chang committed
226
    c_m_o_device_buf.FromDevice(c_m_o_device_result.mData.data());
227
228
229

    if(do_verification)
    {
Anthony Chang's avatar
Anthony Chang committed
230
231
232
233
234
235
236
237
238
        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> a1_m_n(f_host_tensor_descriptor(M, N, N, Row{}));

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_m_k, b0_k_n, a1_m_n, a_element_op, b_element_op, c_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);
239

Anthony Chang's avatar
Anthony Chang committed
240
241
242
243
        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_m_n, b1_n_o, c_m_o_host_result, a_element_op, b_element_op, c_element_op);
244

Anthony Chang's avatar
Anthony Chang committed
245
        ref_gemm1_invoker.Run(ref_gemm1_argument);
246

Anthony Chang's avatar
Anthony Chang committed
247
        return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
248
249
250
251
    }

    return 0;
}