// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#ifndef KERNARG_PRELOAD
template <typename... Args, typename F>
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
14
float launch_and_time_kernel(const StreamConfig& stream_config,
Chao Liu's avatar
Chao Liu committed
15
16
17
18
19
20
21
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
#if CK_TIME_KERNEL
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
22
    if(stream_config.time_kernel_)
Chao Liu's avatar
Chao Liu committed
23
    {
24
#if DEBUG_LOG
Chao Liu's avatar
Chao Liu committed
25
26
27
28
29
30
31
32
33
34
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up 1 time\n");
35
#endif
Chao Liu's avatar
Chao Liu committed
36
37
        // warm up
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
38
        hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
39

40
41
        const int nrepeat = 10;
#if DEBUG_LOG
Chao Liu's avatar
Chao Liu committed
42
        printf("Start running %d times...\n", nrepeat);
43
#endif
Chao Liu's avatar
Chao Liu committed
44
45
46
47
48
49
50
51
        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
52
        for(int i = 0; i < nrepeat; ++i)
Chao Liu's avatar
Chao Liu committed
53
54
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
55
            hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
        }

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        float total_time = 0;

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
70
        hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
71
72
73
74
75

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
76
    hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
77
78
79
80

    return 0;
#endif
}
#else
template <typename... Args, typename F>
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
84
float launch_and_time_kernel(const StreamConfig& stream_config,
85
86
87
88
89
90
91
92
93
94
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
    // Args* args1;
    // hipGetErrorString(hipMalloc(&args1, sizeof(Args)));
    // hip_check_error(hipMemcpy(args1, &args, sizeof(Args), hipMemcpyHostToDevice));
#if CK_TIME_KERNEL
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
95
    if(stream_config.time_kernel_)
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up 1 time\n");
#endif
        //
        // warm up
        const int nrepeat = 1000;
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
112
113
114
        for(auto i = 0; i < nrepeat; i++)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
        hip_check_error(hipGetLastError());

#if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif
        hipEvent_t start, stop;
        float total_time = 0;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());

        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
130
131
132
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(
                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
133
134
135
136
137
138
139
140
141
142
        // hip_check_error(hipGetLastError());

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
        return total_time / nrepeat;
    }
    else
    {
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
143
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
144
145
146
147
148
149
150
151
152
153
154
155
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
    hip_check_error(hipGetLastError());

    return 0;
#endif
}
#endif
template <typename... Args, typename F, typename PreProcessFunc>
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
157
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
158
159
160
161
162
163
164
165
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             std::size_t lds_byte,
                                             Args... args)
{
#if CK_TIME_KERNEL
Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
166
    if(stream_config.time_kernel_)
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up 1 time\n");
#endif
        // warm up
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
183
        hip_check_error(hipGetLastError());
184
185
186
187
188
189
190
191
192
193
194
195
196

        const int nrepeat = 10;
#if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif
        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

Harisankar Sadasivan's avatar
Harisankar Sadasivan committed
197
        for(int i = 0; i < nrepeat; ++i)
198
199
200
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
201
            hip_check_error(hipGetLastError());
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        }

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        float total_time = 0;

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
217
        hip_check_error(hipGetLastError());
218
219
220
221
222

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
223
    hip_check_error(hipGetLastError());
224
225
226
227

    return 0;
#endif
}