cuda_common.h 5.9 KB
Newer Older
1
/**
 *  Copyright (c) 2017 by Contributors
 * @file cuda_common.h
 * @brief Common utilities for CUDA
 */
#ifndef DGL_RUNTIME_CUDA_CUDA_COMMON_H_
#define DGL_RUNTIME_CUDA_CUDA_COMMON_H_

#include <cublas_v2.h>
#include <cuda_runtime.h>
11
#include <curand.h>
12
#include <cusparse.h>
13
#include <dgl/runtime/packed_func.h>
14

15
#include <string>
16

17
18
19
20
21
#include "../workspace_pool.h"

namespace dgl {
namespace runtime {

22
23
/**
 * @brief Test whether a value compares equal to zero.
 *
 * Used by CUDA_KERNEL_CALL to decide whether a launch dimension is empty.
 *
 * @tparam T Any type that can be compared against the literal 0.
 * @param size The value to test.
 * @return true when `size == 0`, false otherwise.
 */
template <typename T>
inline bool is_zero(T size) {
  const bool empty = (size == 0);
  return empty;
}

/**
 * @brief Specialization of is_zero for dim3 extents.
 *
 * @param size The three-component extent to test.
 * @return true when at least one of x, y or z is zero.
 */
template <>
inline bool is_zero<dim3>(dim3 size) {
  // Equivalent to "x == 0 || y == 0 || z == 0", written via De Morgan.
  const bool all_nonzero = (size.x != 0) && (size.y != 0) && (size.z != 0);
  return !all_nonzero;
}

32
33
34
35
/**
 * @brief Evaluate a CUDA driver API expression and LOG(FATAL) on failure.
 *
 * CUDA_SUCCESS and CUDA_ERROR_DEINITIALIZED are both accepted silently. Any
 * other CUresult is reported together with the stringified call expression
 * and the name returned by cuGetErrorName.
 *
 * Implementation notes:
 *  - The macro-local identifiers carry a `dgl_` prefix and trailing
 *    underscore so that an argument expression mentioning a variable called
 *    `result` or `msg` is not captured by the macro's own locals.
 *  - The name pointer starts as nullptr and is checked before streaming,
 *    because cuGetErrorName leaves it null for unrecognized error codes.
 *
 * @param x Expression yielding a CUresult. Evaluated exactly once.
 */
#define CUDA_DRIVER_CALL(x)                                              \
  {                                                                      \
    CUresult dgl_cu_result_ = (x);                                       \
    if (dgl_cu_result_ != CUDA_SUCCESS &&                                \
        dgl_cu_result_ != CUDA_ERROR_DEINITIALIZED) {                    \
      const char* dgl_cu_msg_ = nullptr;                                 \
      cuGetErrorName(dgl_cu_result_, &dgl_cu_msg_);                      \
      LOG(FATAL) << "CUDAError: " #x " failed with error: "              \
                 << (dgl_cu_msg_ ? dgl_cu_msg_ : "unknown error code");  \
    }                                                                    \
  }

42
43
44
45
46
/**
 * @brief Evaluate a CUDA runtime API expression and CHECK that it succeeded.
 *
 * cudaSuccess and cudaErrorCudartUnloading are both accepted. Any other
 * status fails the CHECK with the text from cudaGetErrorString.
 *
 * The status is held in `dgl_cuda_err_` rather than a short name such as
 * `e`: with a short name, a call like `CUDA_CALL(cudaEventSynchronize(e))`
 * would have its `e` captured by the macro's own (still uninitialized) local.
 *
 * @param func Expression yielding a cudaError_t. Evaluated exactly once.
 */
#define CUDA_CALL(func)                                     \
  {                                                         \
    cudaError_t dgl_cuda_err_ = (func);                     \
    CHECK(dgl_cuda_err_ == cudaSuccess ||                   \
          dgl_cuda_err_ == cudaErrorCudartUnloading)        \
        << "CUDA: " << cudaGetErrorString(dgl_cuda_err_);   \
  }

49
50
51
52
53
54
55
56
/**
 * @brief Launch a kernel and CHECK the launch status.
 *
 * The launch is skipped entirely when either the grid or the block extent is
 * zero according to dgl::runtime::is_zero. After a launch, cudaGetLastError
 * is queried; cudaSuccess and cudaErrorCudartUnloading are accepted, anything
 * else fails the CHECK with the text from cudaGetErrorString.
 *
 * Implementation notes:
 *  - `nblks` and `nthrs` are copied into locals first, so each argument
 *    expression is evaluated once instead of once for the zero test and again
 *    for the launch configuration.
 *  - Macro-local names are prefixed/suffixed so they cannot capture
 *    identifiers (for example a variable named `e`) used in the arguments.
 *
 * @param kernel The __global__ function to launch.
 * @param nblks  Grid extent (integral value or dim3).
 * @param nthrs  Block extent (integral value or dim3).
 * @param shmem  Dynamic shared memory size passed to the launch.
 * @param stream Stream passed to the launch.
 * @param ...    Arguments forwarded to the kernel.
 */
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)            \
  {                                                                           \
    const auto dgl_nblks_ = (nblks);                                          \
    const auto dgl_nthrs_ = (nthrs);                                          \
    if (!dgl::runtime::is_zero(dgl_nblks_) &&                                 \
        !dgl::runtime::is_zero(dgl_nthrs_)) {                                 \
      (kernel)<<<dgl_nblks_, dgl_nthrs_, (shmem), (stream)>>>(__VA_ARGS__);   \
      cudaError_t dgl_launch_err_ = cudaGetLastError();                       \
      CHECK(dgl_launch_err_ == cudaSuccess ||                                 \
            dgl_launch_err_ == cudaErrorCudartUnloading)                      \
          << "CUDA kernel launch error: "                                     \
          << cudaGetErrorString(dgl_launch_err_);                             \
    }                                                                         \
  }

59
60
61
62
/**
 * @brief Evaluate a cuSPARSE expression and CHECK that it returned
 *        CUSPARSE_STATUS_SUCCESS.
 *
 * On failure the status value itself is streamed into the CHECK message.
 *
 * The status lives in `dgl_cusparse_status_` rather than `e` so that an
 * argument expression which mentions a variable named `e` is not captured by
 * the macro's own local.
 *
 * @param func Expression yielding a cusparseStatus_t. Evaluated exactly once.
 */
#define CUSPARSE_CALL(func)                                   \
  {                                                           \
    cusparseStatus_t dgl_cusparse_status_ = (func);           \
    CHECK(dgl_cusparse_status_ == CUSPARSE_STATUS_SUCCESS)    \
        << "CUSPARSE ERROR: " << dgl_cusparse_status_;        \
  }

65
66
67
68
/**
 * @brief Evaluate a cuBLAS expression and CHECK that it returned
 *        CUBLAS_STATUS_SUCCESS.
 *
 * On failure the status value itself is streamed into the CHECK message.
 *
 * The status lives in `dgl_cublas_status_` rather than `e` so that an
 * argument expression which mentions a variable named `e` is not captured by
 * the macro's own local.
 *
 * @param func Expression yielding a cublasStatus_t. Evaluated exactly once.
 */
#define CUBLAS_CALL(func)                                 \
  {                                                       \
    cublasStatus_t dgl_cublas_status_ = (func);           \
    CHECK(dgl_cublas_status_ == CUBLAS_STATUS_SUCCESS)    \
        << "CUBLAS ERROR: " << dgl_cublas_status_;        \
  }

71
72
73
74
75
76
77
/**
 * @brief Evaluate a cuRAND expression and CHECK that it returned
 *        CURAND_STATUS_SUCCESS.
 *
 * On failure the message contains the status name from
 * dgl::runtime::curandGetErrorString plus the file and line of the call site.
 *
 * The status lives in `dgl_curand_status_` rather than `e` so that an
 * argument expression which mentions a variable named `e` is not captured by
 * the macro's own local.
 *
 * @param func Expression yielding a curandStatus_t. Evaluated exactly once.
 */
#define CURAND_CALL(func)                                               \
  {                                                                     \
    curandStatus_t dgl_curand_status_ = (func);                         \
    CHECK(dgl_curand_status_ == CURAND_STATUS_SUCCESS)                  \
        << "CURAND Error: "                                             \
        << dgl::runtime::curandGetErrorString(dgl_curand_status_)       \
        << " at " << __FILE__ << ":" << __LINE__;                       \
  }
78
79
80

/**
 * @brief Translate a curandStatus_t into the name of its enumerator.
 *
 * @param error The status code to describe.
 * @return A string literal holding the enumerator name, or
 *         "Unrecognized curand error string" when the code is not one of the
 *         values listed in the table below.
 */
inline const char* curandGetErrorString(curandStatus_t error) {
  struct StatusName {
    curandStatus_t status;
    const char* name;
  };
  static const StatusName kStatusNames[] = {
      {CURAND_STATUS_SUCCESS, "CURAND_STATUS_SUCCESS"},
      {CURAND_STATUS_VERSION_MISMATCH, "CURAND_STATUS_VERSION_MISMATCH"},
      {CURAND_STATUS_NOT_INITIALIZED, "CURAND_STATUS_NOT_INITIALIZED"},
      {CURAND_STATUS_ALLOCATION_FAILED, "CURAND_STATUS_ALLOCATION_FAILED"},
      {CURAND_STATUS_TYPE_ERROR, "CURAND_STATUS_TYPE_ERROR"},
      {CURAND_STATUS_OUT_OF_RANGE, "CURAND_STATUS_OUT_OF_RANGE"},
      {CURAND_STATUS_LENGTH_NOT_MULTIPLE, "CURAND_STATUS_LENGTH_NOT_MULTIPLE"},
      {CURAND_STATUS_DOUBLE_PRECISION_REQUIRED,
       "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"},
      {CURAND_STATUS_LAUNCH_FAILURE, "CURAND_STATUS_LAUNCH_FAILURE"},
      {CURAND_STATUS_PREEXISTING_FAILURE, "CURAND_STATUS_PREEXISTING_FAILURE"},
      {CURAND_STATUS_INITIALIZATION_FAILED,
       "CURAND_STATUS_INITIALIZATION_FAILED"},
      {CURAND_STATUS_ARCH_MISMATCH, "CURAND_STATUS_ARCH_MISMATCH"},
      {CURAND_STATUS_INTERNAL_ERROR, "CURAND_STATUS_INTERNAL_ERROR"},
  };
  for (const StatusName& entry : kStatusNames) {
    if (entry.status == error) {
      return entry.name;
    }
  }
  // Reached only for a status value that is not listed above.
  return "Unrecognized curand error string";
}

112
/**
 * @brief Cast data type to cudaDataType_t.
 *
 * The static member `value` holds the cudaDataType_t enumerator associated
 * with the template argument:
 *   - half   -> CUDA_R_16F
 *   - float  -> CUDA_R_32F
 *   - double -> CUDA_R_64F
 * Any other type falls back to the primary template, which yields CUDA_R_32F.
 */
template <typename T>
struct cuda_dtype {
  // Fallback used when no dedicated specialization exists for T.
  static constexpr cudaDataType_t value{CUDA_R_32F};
};

/** @brief double maps to CUDA_R_64F. */
template <>
struct cuda_dtype<double> {
  static constexpr cudaDataType_t value{CUDA_R_64F};
};

/** @brief float maps to CUDA_R_32F. */
template <>
struct cuda_dtype<float> {
  static constexpr cudaDataType_t value{CUDA_R_32F};
};

/** @brief half maps to CUDA_R_16F. */
template <>
struct cuda_dtype<half> {
  static constexpr cudaDataType_t value{CUDA_R_16F};
};

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
135
#if CUDART_VERSION >= 11000
136
/**
 * @brief Cast index data type to cusparseIndexType_t.
 *
 * The static member `value` holds the cusparseIndexType_t enumerator
 * associated with the template argument:
 *   - int32_t -> CUSPARSE_INDEX_32I
 *   - int64_t -> CUSPARSE_INDEX_64I
 * Any other type falls back to the primary template, which yields
 * CUSPARSE_INDEX_32I.
 */
template <typename T>
struct cusparse_idtype {
  // Fallback used when no dedicated specialization exists for T.
  static constexpr cusparseIndexType_t value{CUSPARSE_INDEX_32I};
};

/** @brief int64_t maps to CUSPARSE_INDEX_64I. */
template <>
struct cusparse_idtype<int64_t> {
  static constexpr cusparseIndexType_t value{CUSPARSE_INDEX_64I};
};

/** @brief int32_t maps to CUSPARSE_INDEX_32I. */
template <>
struct cusparse_idtype<int32_t> {
  static constexpr cusparseIndexType_t value{CUSPARSE_INDEX_32I};
};
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
153
#endif
154

155
/** @brief Thread local workspace */
156
157
class CUDAThreadEntry {
 public:
158
  /** @brief The cusparse handler */
159
  cusparseHandle_t cusparse_handle{nullptr};
160
  /** @brief The cublas handler */
161
  cublasHandle_t cublas_handle{nullptr};
162
  /** @brief The curand generator */
163
  curandGenerator_t curand_gen{nullptr};
164
  /** @brief thread local pool*/
165
  WorkspacePool pool;
166
  /** @brief constructor */
167
168
169
170
  CUDAThreadEntry();
  // get the threadlocal workspace
  static CUDAThreadEntry* ThreadLocal();
};
171

172
/** @brief Get the current CUDA stream */
173
// Declaration only: takes no arguments and returns a cudaStream_t by value.
// The definition is not part of this header.
cudaStream_t getCurrentCUDAStream();
174
175
176
}  // namespace runtime
}  // namespace dgl
#endif  // DGL_RUNTIME_CUDA_CUDA_COMMON_H_