// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <vector>

#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
23
#if DEBUG_LOG
Chao Liu's avatar
Chao Liu committed
24
25
26
27
28
29
30
31
32
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

33
        printf("Warm up %d time\n", stream_config.cold_niters_);
34
#endif
Chao Liu's avatar
Chao Liu committed
35
        // warm up
Jing Zhang's avatar
Jing Zhang committed
36
        for(int i = 0; i < stream_config.cold_niters_; ++i)
37
38
39
40
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
        }
Chao Liu's avatar
Chao Liu committed
41

Jing Zhang's avatar
Jing Zhang committed
42
        const int nrepeat = stream_config.nrepeat_;
43
#if DEBUG_LOG
Chao Liu's avatar
Chao Liu committed
44
        printf("Start running %d times...\n", nrepeat);
45
#endif
Chao Liu's avatar
Chao Liu committed
46
47
48
49
50
        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

51
        std::vector<float> execution_time_series;
Chao Liu's avatar
Chao Liu committed
52

53
        for(int i = 0; i < nrepeat; ++i)
Chao Liu's avatar
Chao Liu committed
54
        {
55
56
57
            float execution_time = 0;
            hip_check_error(hipDeviceSynchronize());
            hip_check_error(hipEventRecord(start, stream_config.stream_id_));
Chao Liu's avatar
Chao Liu committed
58
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
59
            hip_check_error(hipGetLastError());
60
61
62
63
            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
            hip_check_error(hipEventSynchronize(stop));
            hip_check_error(hipEventElapsedTime(&execution_time, start, stop));
            execution_time_series.push_back(execution_time);
Chao Liu's avatar
Chao Liu committed
64
65
        }

66
67
        float mean_execution_time = 0;
        float median_execution_time = 0;
Chao Liu's avatar
Chao Liu committed
68

69
70
71
72
73
74
75
#if DEBUG_LOG
        for(int i =0; i<nrepeat; i++){
           std::cout<<i<<" th launch, execution time = "<<execution_time_series[i]<<" ms"<<std::endl;
        }
#endif

        std::sort(execution_time_series.begin(),execution_time_series.end());
Chao Liu's avatar
Chao Liu committed
76

77
78
        mean_execution_time = std::reduce(execution_time_series.begin(), execution_time_series.end(), .0)/static_cast<float>(nrepeat);
        median_execution_time = execution_time_series[execution_time_series.size()/2];
Chao Liu's avatar
Chao Liu committed
79

80
81
82
83
        if(stream_config.time_kernel_==1)
        return mean_execution_time;
        else
        return median_execution_time;
Chao Liu's avatar
Chao Liu committed
84
85
86
87
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
88
        hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
89
90
91
92
93

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
94
    hip_check_error(hipGetLastError());
Chao Liu's avatar
Chao Liu committed
95
96
97
98

    return 0;
#endif
}
template <typename... Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             std::size_t lds_byte,
                                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up 1 time\n");
#endif
        // warm up
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
127
        hip_check_error(hipGetLastError());
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

        const int nrepeat = 10;
#if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif
        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
145
            hip_check_error(hipGetLastError());
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        }

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        float total_time = 0;

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
161
        hip_check_error(hipGetLastError());
162
163
164
165
166

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
167
    hip_check_error(hipGetLastError());
168
169
170
171

    return 0;
#endif
}