kernel_launch.hpp 2.73 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
5
#pragma once

carlushuang's avatar
carlushuang committed
6
7
8
#include <chrono>

#ifndef CK_NOGPU
Chao Liu's avatar
Chao Liu committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/device_utility/hip_check_error.hpp"

template <typename... Args, typename F>
float launch_and_time_kernel(const StreamConfig& stream_config,
                             F kernel,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             Args... args)
{
#if CK_TIME_KERNEL
    if(stream_config.time_kernel_)
    {
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        const int nrepeat = 10;

        printf("Warm up 1 time\n");

        // warm up
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

        printf("Start running %d times...\n", nrepeat);

        hipEvent_t start, stop;

        hip_check_error(hipEventCreate(&start));
        hip_check_error(hipEventCreate(&stop));

        hip_check_error(hipDeviceSynchronize());
        hip_check_error(hipEventRecord(start, stream_config.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
        }

        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
        hip_check_error(hipEventSynchronize(stop));

        float total_time = 0;

        hip_check_error(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

    return 0;
#endif
}
carlushuang's avatar
carlushuang committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#endif

/// Run a CPU "kernel" callable exactly once.
///
/// @param kernel  any callable accepting the supplied arguments
/// @param args    arguments taken by value and handed to @p kernel as lvalues
///
/// Whatever the callable returns (if anything) is discarded.
template <typename... Args, typename F>
void launch_cpu_kernel(F kernel, Args... args)
{
    // Explicitly drop the callable's result; a cast to void is also valid for void returns.
    static_cast<void>(kernel(args...));
}

/// Run a CPU "kernel" callable repeatedly and report its average wall-clock time.
///
/// The callable is first invoked 3 times as an untimed warm-up. It is then invoked
/// @p nrepeat times back to back, and that timed region is measured.
///
/// @param kernel   any callable accepting the supplied arguments
/// @param nrepeat  number of timed invocations; a value <= 0 performs no timed runs
/// @param args     arguments taken by value and handed to @p kernel on every call
/// @return mean time per timed invocation in milliseconds, or 0 when nrepeat <= 0
template <typename... Args, typename F>
float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
{
    const int nwarmup = 3;

    for(int i = 0; i < nwarmup; ++i)
        kernel(args...);

    // Avoid dividing by zero (NaN/inf) or by a negative count below.
    if(nrepeat <= 0)
        return 0.0f;

    // steady_clock is monotonic; high_resolution_clock may be an alias of a wall clock.
    const auto start = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
    {
        kernel(args...);
    }
    const auto stop = std::chrono::steady_clock::now();

    // Convert straight to fractional milliseconds instead of truncating to whole microseconds.
    const std::chrono::duration<float, std::milli> elapsed = stop - start;

    return elapsed.count() / static_cast<float>(nrepeat);
}