// common.h
#pragma once

#include <cstddef>
#include <cassert>
#include <cmath>
#include <exception>
#include <iostream>
#include <fstream>
#include <sstream>
#include <memory>
#include <source_location>
#include <vector>
#include <list>
#include <stack>
#include <map>
#include <unordered_map>
#include <set>
#include <any>
#include <variant>
#include <optional>
#include <chrono>
#include <functional>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <spdlog/spdlog.h>

class CUDAError : public std::runtime_error {
public:
Muyang Li's avatar
Muyang Li committed
28
    CUDAError(cudaError_t errorCode, std::source_location location)
Zhekai Zhang's avatar
Zhekai Zhang committed
29
30
31
32
33
34
35
36
        : std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}

public:
    const cudaError_t errorCode;
    const std::source_location location;

private:
    static std::string format(cudaError_t errorCode, std::source_location location) {
Muyang Li's avatar
Muyang Li committed
37
38
        return spdlog::fmt_lib::format(
            "CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
Zhekai Zhang's avatar
Zhekai Zhang committed
39
40
41
    }
};

Muyang Li's avatar
Muyang Li committed
42
43
inline cudaError_t checkCUDA(cudaError_t retValue,
                             const std::source_location location = std::source_location::current()) {
Zhekai Zhang's avatar
Zhekai Zhang committed
44
    if (retValue != cudaSuccess) {
45
        (void)cudaGetLastError();
Zhekai Zhang's avatar
Zhekai Zhang committed
46
47
48
49
50
        throw CUDAError(retValue, location);
    }
    return retValue;
}

Muyang Li's avatar
Muyang Li committed
51
52
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
                                  const std::source_location location = std::source_location::current()) {
Zhekai Zhang's avatar
Zhekai Zhang committed
53
    if (retValue != CUBLAS_STATUS_SUCCESS) {
Muyang Li's avatar
Muyang Li committed
54
55
        throw std::runtime_error(spdlog::fmt_lib::format(
            "CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
Zhekai Zhang's avatar
Zhekai Zhang committed
56
57
58
59
60
61
62
63
64
65
66
67
68
    }
    return retValue;
}

// Per-thread stack of streams; CUDAStreamContext pushes/pops entries and the top is "current".
inline thread_local std::stack<cudaStream_t> stackCUDAStreams;

/// Returns the stream on top of this thread's stack, or a null (0) stream when the stack is empty.
inline cudaStream_t getCurrentCUDAStream() {
    const bool nothingPushed = stackCUDAStreams.empty();
    return nothingPushed ? cudaStream_t{} : stackCUDAStreams.top();
}

muyangli's avatar
muyangli committed
69
70
71
72
73
74
75
struct CUDAStreamContext {
    cudaStream_t stream;

    CUDAStreamContext(cudaStream_t stream) : stream(stream) {
        stackCUDAStreams.push(stream);
    }
    CUDAStreamContext(const CUDAStreamContext &) = delete;
Muyang Li's avatar
Muyang Li committed
76
77
    CUDAStreamContext(CUDAStreamContext &&)      = delete;

muyangli's avatar
muyangli committed
78
79
80
81
82
83
84
85
86
87
88
89
90
    ~CUDAStreamContext() {
        assert(stackCUDAStreams.top() == stream);
        stackCUDAStreams.pop();
    }
};

struct CUDAStreamWrapper {
    cudaStream_t stream;

    CUDAStreamWrapper() {
        checkCUDA(cudaStreamCreate(&stream));
    }
    CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
Muyang Li's avatar
Muyang Li committed
91
    CUDAStreamWrapper(CUDAStreamWrapper &&)      = delete;
muyangli's avatar
muyangli committed
92
93
94
95
96
97
98
99
100
101
102
103
104

    ~CUDAStreamWrapper() {
        checkCUDA(cudaStreamDestroy(stream));
    }
};

struct CUDAEventWrapper {
    cudaEvent_t event;

    CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
        checkCUDA(cudaEventCreateWithFlags(&event, flags));
    }
    CUDAEventWrapper(const CUDAEventWrapper &) = delete;
Muyang Li's avatar
Muyang Li committed
105
    CUDAEventWrapper(CUDAEventWrapper &&)      = delete;
muyangli's avatar
muyangli committed
106
107
108
109
110
111

    ~CUDAEventWrapper() {
        checkCUDA(cudaEventDestroy(event));
    }
};

112
113
114
115
116
117
118
119
120
121
122
123
124
/**
 * 1. hold one when entered from external code (set `device` to -1 to avoid device change)
 * 2. hold one when switching device
 * 3. hold one with `disableCache` when calling external code that may change the device
 */
class CUDADeviceContext {
public:
    CUDADeviceContext(int device = -1, bool disableCache = false) : disableCache(disableCache) {
        if (cacheDisabled()) {
            // no previous context => we might entered from external code, reset cache
            // previous context is reset on => external code may be executed, reset
            currentDeviceCache = -1;
        }
Muyang Li's avatar
Muyang Li committed
125

126
127
128
129
130
131
132
133
134
135
136
137
        ctxs.push(this);
        lastDevice = getDevice();
        if (device >= 0) {
            setDevice(device);
        }

        if (disableCache) {
            // we are about to call external code, reset cache
            currentDeviceCache = -1;
        }
    }
    CUDADeviceContext(const CUDADeviceContext &) = delete;
Muyang Li's avatar
Muyang Li committed
138
    CUDADeviceContext(CUDADeviceContext &&)      = delete;
139
140
141
142
143
144
145
146
147
148
149
150
151

    ~CUDADeviceContext() {
        if (disableCache) {
            // retured from external code, cache is not reliable, reset
            currentDeviceCache = -1;
        }

        setDevice(lastDevice);
        assert(ctxs.top() == this);
        ctxs.pop();

        if (cacheDisabled()) {
            // ctxs.empty() => we are about to return to external code, reset cache
Muyang Li's avatar
Muyang Li committed
152
153
            // otherwise => we are a nested context in a previous context with reset on, we might continue to execute
            // external code, reset
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
            currentDeviceCache = -1;
        }
    }

    const bool disableCache;
    int lastDevice;

public:
    static int getDevice() {
        int idx = -1;
        if (cacheDisabled() || currentDeviceCache < 0) {
            checkCUDA(cudaGetDevice(&idx));
        } else {
            idx = currentDeviceCache;
        }
        currentDeviceCache = cacheDisabled() ? -1 : idx;
        return idx;
    }
Muyang Li's avatar
Muyang Li committed
172

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
private:
    static void setDevice(int idx) {
        // TODO: deal with stream when switching device
        assert(idx >= 0);
        if (!cacheDisabled() && currentDeviceCache == idx) {
            return;
        }
        checkCUDA(cudaSetDevice(idx));
        currentDeviceCache = cacheDisabled() ? -1 : idx;
    }

private:
    static inline thread_local std::stack<CUDADeviceContext *> ctxs;
    static inline thread_local int currentDeviceCache = -1;

    static bool cacheDisabled() {
        return ctxs.empty() || ctxs.top()->disableCache;
    }
};

Zhekai Zhang's avatar
Zhekai Zhang committed
193
inline cudaDeviceProp *getCurrentDeviceProperties() {
194
195
196
197
198
199
200
201
202
    static thread_local std::map<int, cudaDeviceProp> props;

    int deviceId = CUDADeviceContext::getDevice();
    if (!props.contains(deviceId)) {
        cudaDeviceProp prop;
        checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
        props[deviceId] = prop;
    }
    return &props.at(deviceId);
Zhekai Zhang's avatar
Zhekai Zhang committed
203
204
205
206
207
208
209
}

/**
 * Integer division rounded toward positive infinity, i.e. ceil(a / b).
 *
 * Computed from the truncating quotient and remainder instead of (a + b - 1) / b. That form
 * overflows when `a` is near the top of T (for unsigned T it silently wraps:
 * ceilDiv(UINT_MAX, 2u) would yield 0) and mis-rounds some negative inputs.
 *
 * Requires an integral T and b != 0.
 */
template<typename T>
constexpr T ceilDiv(T a, T b) {
    T quotient        = a / b;
    const T remainder = a % b;
    // Division truncates toward zero; round up when the exact quotient is positive and
    // non-integral, i.e. the remainder is non-zero and has the same sign as the divisor.
    if (remainder != 0 && ((remainder > 0) == (b > 0))) {
        ++quotient;
    }
    return quotient;
}

/**
 * Returns ceil(log2(value)) for value >= 1, and 0 for value <= 1.
 *
 * Repeatedly halves `value` rounding up (for value >= 0, value - value / 2 == ceil(value / 2)),
 * counting steps until it reaches 1. The earlier (value + 1) / 2 step overflowed at the top of
 * T's range (e.g. log2Up(UINT_MAX) yielded 1 instead of 32); this form never exceeds `value`.
 *
 * Intended for integral T.
 */
template<typename T>
constexpr int log2Up(T value) {
    int steps = 0;
    while (value > 1) {
        value = value - value / 2;
        ++steps;
    }
    return steps;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
219
220
221
222
223
224
struct CUBLASWrapper {
    cublasHandle_t handle = nullptr;

    CUBLASWrapper() {
        checkCUBLAS(cublasCreate(&handle));
    }
Muyang Li's avatar
Muyang Li committed
225
    CUBLASWrapper(CUBLASWrapper &&)       = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
    CUBLASWrapper(const CUBLASWrapper &&) = delete;
    ~CUBLASWrapper() {
        if (handle) {
            checkCUBLAS(cublasDestroy(handle));
        }
    }
};

/**
 * Returns this thread's shared CUBLASWrapper, creating a new one if none is currently alive.
 *
 * Only a weak reference is cached (thread_local), so the wrapper is destroyed once every
 * returned shared_ptr is released; the next call then creates a fresh one.
 */
inline std::shared_ptr<CUBLASWrapper> getCUBLAS() {
    static thread_local std::weak_ptr<CUBLASWrapper> inst;
    if (auto existing = inst.lock()) {
        return existing;
    }
    auto created = std::make_shared<CUBLASWrapper>();
    inst         = created;
    return created;
}