// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
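    // The experimental int4 type is stored in an int8-sized container, so raw byte
    // copies between the host and the kernel's int8-backed buffers below stay valid.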
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif

    using namespace ck::literals;

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;

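    // Row-major matrices stride by 'stride' between consecutive rows; column-major
    // matrices stride by 'stride' between consecutive columns.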
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(stride == 0)
            {
                // if stride is zero, fall back to the packed (contiguous) default
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return col;
                }
                else
                {
                    return row;
                }
            }
            else
            {
                return stride;
            }
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

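    // Fill A and B according to config.init_method: constant, integer-valued, or
    // uniform real distributions over progressively narrower ranges.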
    switch(config.init_method)
    {
    case 0:
        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
        break;
    case 1:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 2:
        ck::utils::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 3:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
        break;
    case 4:
        ck::utils::FillUniformDistribution<ADataType>{0.0f, 0.1f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-0.01f, 0.01f}(b_k_n);
        break;
    case 5:
        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 6:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

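    // In the int4 build, host tensors are converted to the kernel's storage types
    // before the device copy; the sizes match by the static_assert above.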
#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
                               c_m_n_device_result.mDesc.GetElementSpaceSize());

    const Tensor<KernelADataType> a_m_k_converted(a_m_k);
    const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    float best_perf         = 0.f;
    float best_time         = 0.f;
    std::string best_kernel = "";

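    // Visit every kernel instance in the DeviceGemmFactory tuple at compile time and
    // benchmark each one on the same problem.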
    ck::static_for<0, std::tuple_size_v<DeviceGemmFactory>, 1>{}([&](auto i) -> void {
        const auto device_gemm_instance = std::get<i>(DeviceGemmFactory{});

        using DeviceGemmInstance = ck::remove_cvref_t<decltype(device_gemm_instance)>;
        // do GEMM
        auto gemm      = DeviceGemmInstance{};
        auto invoker   = gemm.MakeInvoker();
        float ave_time = 0;

        auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
            static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
            M,
            N,
            K,
            StrideA,
            StrideB,
            StrideC,
            a_element_op,
            b_element_op,
            c_element_op);
#if 0
        // Disabled: when enabled, unsupported instances are skipped rather than run.
        if(!gemm.IsSupportedArgument(argument))
        {
            std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;

            return;
        }
#endif
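        // Run (and optionally time) the kernel. The trailing StreamConfig values are
        // assumed to select the log level and the warm-up/repeat iteration counts.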
        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 300, 3000});

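        // A GEMM does M*N*K multiply-accumulates, i.e. 2*M*N*K flops; num_btype counts
        // the bytes moved for A, B, and C.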
        std::size_t flop = 2_uz * M * N * K;
        std::size_t num_btype =
            sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

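        // ave_time is in ms, so flop / 1e9 / ms yields TFLOP/s.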
        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
        if(tflops > best_perf)
        {
            best_perf   = tflops;
            best_time   = ave_time;
            best_kernel = gemm.GetTypeString();
        }

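        // Likewise, bytes / 1e6 / ms yields GB/s.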
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        if(config.time_kernel == 1)
        {
            std::cout << "Perf mode Mean: " << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl;
        }
        else if(config.time_kernel == 2)
        {
            std::cout << "Perf mode Median: " << ave_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl;
        }
    });

    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Problem Size: M: " << M << ", N: " << N << ", K: " << K << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
              << " ms" << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;

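    // Check the device result against a host reference GEMM.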
    if(config.do_verification)
    {
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

        ref_invoker.Run(ref_argument);

#ifdef BUILD_INT4_EXAMPLE
        // The device buffer holds KernelCDataType values; convert back to CDataType
        // before comparing with the host reference.
        Tensor<KernelCDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);

        c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());

        c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();

        return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#else
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
    }

    return true;
}

bool run_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

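    // Short-circuits: if argument parsing fails, the run is skipped.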
    return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}

bool run_gemm_streamk_example(int argc, char* argv[])
{
    ProblemSizeStreamK problem_size;
    ExecutionConfig config;

    return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}