flush_cache.hpp 13 KB
Newer Older
ltqin's avatar
ltqin committed
1
2
3
4
5
6
7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/flush_icache.hpp"
namespace ck {
namespace utility {

17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
    static constexpr index_t NumDs = DsDataType::Size();

    using ADataType     = decltype(Argument::p_a_grid);
    using BDataType     = decltype(Argument::p_b_grid);
    using DsGridPointer = decltype(Argument::p_ds_grid);

    RotatingMemWrapperMultiD() = delete;
    RotatingMemWrapperMultiD(Argument& arg_,
                             std::size_t rotating_count_,
                             std::size_t size_a_,
                             std::size_t size_b_,
                             std::array<std::size_t, NumDs> size_ds_)
        : arg(arg_),
          rotating_count(rotating_count_),
          size_a(size_a_),
          size_b(size_b_),
          size_ds(size_ds_)
    {
        p_a_grids.push_back(arg.p_a_grid);
        p_b_grids.push_back(arg.p_b_grid);
        p_ds_grids.push_back(arg.p_ds_grid);
        for(size_t i = 1; i < rotating_count; i++)
        {
            {
                void* pADeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
                hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
                                          const_cast<void*>(p_a_grids[0]),
                                          size_a_,
                                          hipMemcpyDeviceToDevice));
                p_a_grids.push_back(pADeviceBuf);
            }

            {
                void* pBDeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
                hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
                                          const_cast<void*>(p_b_grids[0]),
                                          size_b_,
                                          hipMemcpyDeviceToDevice));
                p_b_grids.push_back(pBDeviceBuf);
            }

            {

                DsGridPointer ds_buffer;
                static_for<0, NumDs, 1>{}([&](auto j) {
                    void* pDDeviceBuf;
                    hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
                    hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
                                              static_cast<const void*>(p_ds_grids[0][j]),
                                              size_ds_[j],
                                              hipMemcpyDeviceToDevice));

                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
                });

                p_ds_grids.push_back(ds_buffer);
            }
        }
    }

    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_a_grid    = reinterpret_cast<ADataType>(p_a_grids[idx]);
            arg.p_b_grid    = reinterpret_cast<BDataType>(p_b_grids[idx]);
            arg.p_ds_grid   = p_ds_grids[idx];
        }
    }
    void Print()
    {
aska-0096's avatar
aska-0096 committed
96
97
98
99
100
        std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b;
        static_for<0, NumDs, 1>{}([&](auto j) {
            std::cout << ", size_d" <<j.value<<": "<< size_ds[j];
        });
        std::cout << ", rotating_count: " << rotating_count << "}" << std::endl;
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    }
    ~RotatingMemWrapperMultiD()
    {
        if(rotating_count > 1)
        {
            // restore ptr
            arg.p_a_grid  = reinterpret_cast<ADataType>(p_a_grids[0]);
            arg.p_b_grid  = reinterpret_cast<BDataType>(p_b_grids[0]);
            arg.p_ds_grid = p_ds_grids[0];

            // free device mem
            for(size_t i = 1; i < rotating_count; i++)
            {
                hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
                hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));

                static_for<0, NumDs, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
                    hip_check_error(
                        hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
                });
            }
        }
    }

    private:
    Argument& arg;
    std::size_t iter                       = 0;
    std::size_t rotating_count             = 1;
    std::size_t size_a                     = 0;
    std::size_t size_b                     = 0;
    std::array<std::size_t, NumDs> size_ds = {0};
    std::vector<const void*> p_a_grids;
    std::vector<const void*> p_b_grids;
    std::vector<DsGridPointer> p_ds_grids;
};

ltqin's avatar
ltqin committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Keeps several device-side copies of the A and B input buffers referenced by a GEMM argument
// and points the argument at a different copy on every Next() call.
//
// Slot 0 always holds the caller's original pointers. Slots 1 .. rotating_count-1 are allocated
// with hipMalloc and filled from slot 0 with a device-to-device hipMemcpy in the constructor.
// The destructor writes the original pointers back into the argument and frees the extra copies.
//
// Template parameters:
//   Argument - type with assignable p_a_grid and p_b_grid members.
//
// The wrapper keeps a reference to the argument, so the argument must outlive the wrapper.
template <typename Argument>
struct RotatingMemWrapper
{
    using ADataType = decltype(Argument::p_a_grid);
    using BDataType = decltype(Argument::p_b_grid);

    RotatingMemWrapper() = delete;

    // The destructor frees the extra device buffers; a copy would free the same allocations a
    // second time, so copying is not allowed.
    RotatingMemWrapper(const RotatingMemWrapper&) = delete;
    RotatingMemWrapper& operator=(const RotatingMemWrapper&) = delete;

    // arg_            - argument whose buffer pointers are rotated (modified in place).
    // rotating_count_ - total number of buffer sets, the original included; 0 or 1 allocates
    //                   nothing and turns Next() into a no-op.
    // size_a_/size_b_ - byte sizes used to allocate and copy the A and B buffers.
    RotatingMemWrapper(Argument& arg_,
                       std::size_t rotating_count_,
                       std::size_t size_a_,
                       std::size_t size_b_)
        : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
    {
        // Slot 0: the buffers the caller already owns.
        p_a_grids.push_back(arg.p_a_grid);
        p_b_grids.push_back(arg.p_b_grid);

        for(std::size_t i = 1; i < rotating_count; i++)
        {
            p_a_grids.push_back(CloneDeviceBuffer(p_a_grids[0], size_a));
            p_b_grids.push_back(CloneDeviceBuffer(p_b_grids[0], size_b));
        }
    }

    // Points the argument at the next buffer set. The first call selects slot 0 (the original
    // buffers); later calls cycle through all slots in order.
    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_a_grid    = reinterpret_cast<ADataType>(p_a_grids[idx]);
            arg.p_b_grid    = reinterpret_cast<BDataType>(p_b_grids[idx]);
        }
    }

    // Writes the configured byte sizes and the rotation count to std::cout.
    void Print() const
    {
        std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
                  << ", rotating_count: " << rotating_count << "}" << std::endl;
    }

    // Restores the caller's original pointers and releases every extra copy.
    // NOTE(review): hip_check_error is called from a destructor; if it reports failures by
    // throwing, a failed hipFree terminates the program - confirm that is acceptable.
    ~RotatingMemWrapper()
    {
        if(rotating_count > 1)
        {
            // restore ptr
            arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
            arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);

            // free device mem (slot 0 belongs to the caller and is left alone)
            for(std::size_t i = 1; i < rotating_count; i++)
            {
                hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
                hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
            }
        }
    }

    private:
    // Allocates num_bytes of device memory and fills it with a device-to-device copy of src.
    static void* CloneDeviceBuffer(const void* src, std::size_t num_bytes)
    {
        void* dst = nullptr;
        hip_check_error(hipMalloc(&dst, num_bytes));
        hip_check_error(hipMemcpy(dst, src, num_bytes, hipMemcpyDeviceToDevice));
        return dst;
    }

    Argument& arg;
    std::size_t iter           = 0;
    std::size_t rotating_count = 1;
    std::size_t size_a         = 0;
    std::size_t size_b         = 0;
    std::vector<const void*> p_a_grids;
    std::vector<const void*> p_b_grids;
};

inline void flush_icache()
{
    hipDeviceProp_t deviceProps;
    hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
    int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;

    ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
    hip_check_error(hipGetLastError());
}
// if TimePrePress == false, return time does not include preprocess's time
228
229
230
231
232
template <bool TimePreprocess,
          typename GemmArgs,
          typename... Args,
          typename F,
          typename PreProcessFunc>
ltqin's avatar
ltqin committed
233
234
235
236
237
238
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             std::size_t lds_byte,
239
240
                                             GemmArgs& gemm_args,
                                             Args... args)
ltqin's avatar
ltqin committed
241
242
{
#if CK_TIME_KERNEL
chenjun's avatar
chenjun committed
243
#define MEDIAN 0
ltqin's avatar
ltqin committed
244
245
    if(stream_config.time_kernel_)
    {
246
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
247
        {
248
            printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
249
250
251
252
253
254
255
256
257
258
                   __func__,
                   grid_dim.x,
                   grid_dim.y,
                   grid_dim.z,
                   block_dim.x,
                   block_dim.y,
                   block_dim.z);

            printf("Warm up %d times\n", stream_config.cold_niters_);
        }
ltqin's avatar
ltqin committed
259
260
261
        // warm up
        for(int i = 0; i < stream_config.cold_niters_; ++i)
        {
262
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
ltqin's avatar
ltqin committed
263
264
265
266
267
268
269
270
            hip_check_error(hipGetLastError());
        }

        const int nrepeat = stream_config.nrepeat_;
        if(nrepeat == 0)
        {
            return 0.0;
        }
271
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
272
273
274
        {
            printf("Start running %d times...\n", nrepeat);
        }
ltqin's avatar
ltqin committed
275
276
277
278
279
280

#if MEDIAN
        std::set<float> times;
#else
        float total_time = 0;
#endif
chenjun's avatar
chenjun committed
281
282
283
284
285
286
287
        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
chenjun's avatar
chenjun committed
288

ltqin's avatar
ltqin committed
289
290
291
292
293
294
295
        for(int i = 0; i < nrepeat; ++i)
        {
            if constexpr(!TimePreprocess)
            {
                preprocess();
            }

chenjun's avatar
chenjun committed
296
            // hipEvent_t start, stop;
ltqin's avatar
ltqin committed
297

chenjun's avatar
chenjun committed
298
299
            // hip_check_error(hipEventCreate(&start));
            // hip_check_error(hipEventCreate(&stop));
ltqin's avatar
ltqin committed
300

chenjun's avatar
chenjun committed
301
302
            // hip_check_error(hipDeviceSynchronize());
            // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
ltqin's avatar
ltqin committed
303
304
305
306
307
308
            // calculate preprocess time
            if constexpr(TimePreprocess)
            {
                preprocess();
            }
            // run real kernel
309
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
ltqin's avatar
ltqin committed
310
311
312
            hip_check_error(hipGetLastError());
            // end real kernel

chenjun's avatar
chenjun committed
313
314
315
316
317
318
319
320
321
            //             hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
            //             hip_check_error(hipEventSynchronize(stop));
            //             float cur_time = 0;
            //             hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
            // #if MEDIAN
            //             times.insert(cur_time);
            // #else
            //             total_time += cur_time;
            // #endif
ltqin's avatar
ltqin committed
322

323
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
324
            {
chenjun's avatar
chenjun committed
325
                // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
ltqin's avatar
ltqin committed
326

327
328
329
                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
                       static_cast<const void*>(gemm_args.p_a_grid),
                       static_cast<const void*>(gemm_args.p_b_grid));
330
            }
ltqin's avatar
ltqin committed
331
        }
chenjun's avatar
chenjun committed
332
333
334
335
336
337
338
339
340
        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));
        float cur_time = 0;
        hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
        times.insert(cur_time);
#else
        total_time += cur_time;
#endif
ltqin's avatar
ltqin committed
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355

#if MEDIAN
        auto mid = times.begin();
        std::advance(mid, (nrepeat - 1) / 2);
        if(nrepeat % 2 == 1)
        {
            return *mid;
        }
        else
        {
            auto mid_next = mid;
            std::advance(mid_next, 1);
            return (*mid + *mid_next) / 2;
        }
#else
chenjun's avatar
chenjun committed
356
        // return total_time / nrepeat;
357
358
359
360
        hipDeviceProp_t deviceProps;
        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
        float preprocess_offset = deviceProps.multiProcessorCount==80? 0.005 : 0.01;
        return (total_time - preprocess_offset * nrepeat) / nrepeat;
ltqin's avatar
ltqin committed
361
362
363
364
365
#endif
    }
    else
    {
        preprocess();
366
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
ltqin's avatar
ltqin committed
367
368
369
370
371
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
372
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
ltqin's avatar
ltqin committed
373
374
375
376
377
378
379
380
    hip_check_error(hipGetLastError());

    return 0;
#endif
}

} // namespace utility
} // namespace ck