aiter_hip_common.h 5.88 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
// SPDX-License-Identifier: MIT
 
#pragma once
4
#include "aiter_enum.h"
Xiaowei.zhang's avatar
Xiaowei.zhang committed
5
6
7
#include "ck_tile/core.hpp"
#include <hip/hip_runtime.h>
#include <iostream>
8
9
10
#include <sstream>
#include <stdexcept>
#include <utility>
Xiaowei.zhang's avatar
Xiaowei.zhang committed
11
12
13
14
15
16
17
18

/// Scoped enumeration of the GPU architecture tags known to this header.
/// Enumerators take the default consecutive values starting at 0.
/// NOTE(review): the names look like `gfx*` architecture identifiers such as
/// those reported in hipDeviceProp_t::gcnArchName — confirm the mapping with
/// the code that consumes this enum.
enum class GPUArch
{
    gfx936,
    gfx938,
    gfx946,
};

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
namespace aiter_detail {

/// Per-thread switch consulted by check_fail(): when true a failed check
/// throws std::runtime_error, when false (the default) it aborts the process.
inline thread_local bool g_aiter_can_throw = false;

/// Non-recoverable error reporter used by HIP_CALL.
///
/// Prints "[AITER] <file>:<line> " followed by every argument in `args`
/// (streamed in order) to std::cerr, then calls std::abort(). It never looks
/// at g_aiter_can_throw and never returns.
template <typename... Args>
[[noreturn, gnu::noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args)
{
    auto& sink = std::cerr;
    sink << "[AITER] " << file << ":" << line << " ";
    (sink << ... << std::forward<Args>(args));
    sink << std::endl;
    std::abort();
}

/// Recoverable-or-fatal check failure reporter used by AITER_CHECK.
///
/// Builds "[AITER] <file>:<line>" followed by either the streamed `args`
/// (separated from the location by one space) or the text " check failed" when
/// no arguments were supplied. The message is always echoed to std::cerr.
///
/// @throws std::runtime_error carrying the message when g_aiter_can_throw is
///         true on the calling thread; otherwise the process is aborted.
template <typename... Args>
[[noreturn]] inline void check_fail(const char* file, int line, Args&&... args)
{
    std::ostringstream text;
    text << "[AITER] " << file << ":" << line;
    if constexpr(sizeof...(Args) == 0)
    {
        text << " check failed";
    }
    else
    {
        text << " ";
        (text << ... << std::forward<Args>(args));
    }

    std::string report = text.str();
    std::cerr << report << std::endl;

    if(!g_aiter_can_throw)
    {
        std::abort();
    }
    throw std::runtime_error(std::move(report));
}
} // namespace aiter_detail

// AITER_CHECK(cond, msg...): evaluates `cond` exactly once; when it is false,
// calls aiter_detail::check_fail with the current __FILE__/__LINE__ and the
// optional message arguments, which check_fail streams in order with
// operator<<. check_fail does not return: it throws std::runtime_error when
// aiter_detail::g_aiter_can_throw is true on the calling thread and aborts
// otherwise. __VA_OPT__ lets the message arguments be omitted entirely.
// The do { ... } while(0) wrapper makes the macro usable as a single statement.
#define AITER_CHECK(x, ...)                                                          \
    do                                                                               \
    {                                                                                \
        if(!(x)) [[unlikely]]                                                        \
        {                                                                            \
            aiter_detail::check_fail(__FILE__, __LINE__ __VA_OPT__(, ) __VA_ARGS__); \
        }                                                                            \
    } while(0)

// HIP_CALL(call): evaluates the HIP runtime expression `call` exactly once and
// treats any result other than hipSuccess as fatal: the failing expression text
// and hipGetErrorString() of the result are reported through
// aiter_detail::aiter_check_fatal, which prints to std::cerr and aborts.
//
// The argument is parenthesized and the temporary has a macro-specific name so
// that an expression mentioning a caller variable called `err` is not silently
// captured by the macro's own local.
#define HIP_CALL(call)                                                                  \
    do                                                                                  \
    {                                                                                   \
        const hipError_t aiter_hip_call_err_ = (call);                                  \
        if(aiter_hip_call_err_ != hipSuccess) [[unlikely]]                              \
        {                                                                               \
            aiter_detail::aiter_check_fatal(__FILE__,                                   \
                                            __LINE__,                                   \
                                            "fail to call " #call " ---> [HIP error](", \
                                            hipGetErrorString(aiter_hip_call_err_),     \
                                            ')');                                       \
        }                                                                               \
    } while(0)

/// Plain aggregate of three `unsigned int` members.
/// NOTE(review): presumably a 12-byte filler used when laying out packed
/// kernel-argument structs — confirm against the structs that embed it.
struct p3
{
    unsigned _p0;
    unsigned _p1;
    unsigned _p2;
};
/// Plain aggregate of two `unsigned int` members.
/// NOTE(review): presumably an 8-byte filler used when laying out packed
/// kernel-argument structs — confirm against the structs that embed it.
struct p2
{
    unsigned _p0;
    unsigned _p1;
};
/// Plain aggregate of a single `unsigned int` member.
/// NOTE(review): presumably a 4-byte filler used when laying out packed
/// kernel-argument structs — confirm against the structs that embed it.
struct p1
{
    unsigned _p0;
};

/// Launch description consumed by AiterAsmKernel::launch_kernel.
struct AiterAsmKernelArgs
{
    /// Passed to the launch as the value following HIP_LAUNCH_PARAM_BUFFER_POINTER.
    void* args_ptr;
    /// Passed to the launch as the value following HIP_LAUNCH_PARAM_BUFFER_SIZE.
    void* arg_size_ptr;

    // Grid dimensions handed to hipModuleLaunchKernel (x, y, z).
    int gdx;
    int gdy;
    int gdz;

    // Block dimensions handed to hipModuleLaunchKernel (x, y, z).
    int bdx;
    int bdy;
    int bdz;

    /// Stream the kernel is launched on. Being a const member, it makes the
    /// struct non-assignable after construction.
    const hipStream_t stream;
};

class AiterAsmKernel
{
private:
    hipModule_t module;
    hipFunction_t kernel_func;

public:
    AiterAsmKernel(const char *name, const char *hsaco)
    {
        const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
        std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name;
        HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str()));
        HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
        std::cout << " Success" << std::endl;
    };

    ~AiterAsmKernel()
    {
        HIP_CALL(hipModuleUnload(module));
    }

    void launch_kernel(const AiterAsmKernelArgs &kargs)
    {
        void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kargs.args_ptr,
                          HIP_LAUNCH_PARAM_BUFFER_SIZE, kargs.arg_size_ptr,
                          HIP_LAUNCH_PARAM_END};

        HIP_CALL(hipModuleLaunchKernel(kernel_func,
                                       kargs.gdx, kargs.gdy, kargs.gdz,
                                       kargs.bdx, kargs.bdy, kargs.bdz,
                                       0, kargs.stream, nullptr, (void **)&config));
    };
};

static const std::string get_gpu_arch()
{
    int device_count;
    hipError_t err = hipGetDeviceCount(&device_count);
    if(err != hipSuccess || device_count == 0)
    {
        return "No GPU Found";
    }

    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, 0);

    std::string arch_full = prop.gcnArchName;
    size_t colon_pos      = arch_full.find(':');
    if(colon_pos != std::string::npos)
    {
        return arch_full.substr(0, colon_pos);
    }
    else
    {
        return arch_full;
    }
}

static const uint32_t get_num_cu_func()
{
    auto get_num_cu_local = []() {
        hipDevice_t dev;
        hipDeviceProp_t dev_prop;
        HIP_CALL(hipGetDevice(&dev));
        HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
        return dev_prop.multiProcessorCount;
    };
    static const uint32_t num_cu = get_num_cu_local();
    return num_cu;
}
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

/// Scope guard for the active HIP device: the constructor records the device
/// that is current and switches to `device_id`; the destructor switches back to
/// the recorded device. Per the original note, AiterTensor factory methods and
/// other code that temporarily changes device rely on it.
///
/// Every HIP call goes through HIP_CALL, so a failure in either the
/// constructor or the destructor aborts the process. The guard is non-copyable.
class HipDeviceGuard
{
    int saved_device_{};

public:
    HipDeviceGuard(const HipDeviceGuard&)            = delete;
    HipDeviceGuard& operator=(const HipDeviceGuard&) = delete;

    /// @param device_id  Device to make current for the lifetime of the guard.
    explicit HipDeviceGuard(int device_id)
    {
        HIP_CALL(hipGetDevice(&saved_device_));
        HIP_CALL(hipSetDevice(device_id));
    }

    /// Re-selects the device that was current when the guard was created.
    ~HipDeviceGuard() noexcept
    {
        HIP_CALL(hipSetDevice(saved_device_));
    }
};