/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/utils.h
 * @brief Utilities for CUDA kernels.
 */
#ifndef DGL_ARRAY_CUDA_UTILS_H_
#define DGL_ARRAY_CUDA_UTILS_H_

#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>

#include <cub/cub.cuh>
#include <type_traits>

#include "../../runtime/cuda/cuda_common.h"

namespace dgl {
namespace cuda {

#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF
#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF
#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
// The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256

/** @brief Calculate the number of threads needed given the dimension length.
 *
 * It finds the largest power of two that is no larger than
 * min(dim, max_nthrs).
 */
inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
  CHECK_GE(dim, 0);
  if (dim == 0) return 1;
  int ret = max_nthrs;
  while (ret > dim) {
    ret = ret >> 1;
  }
  return ret;
}
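// Illustrative examples (assuming the default cap of CUDA_MAX_NUM_THREADS):
//   FindNumThreads(100);  // -> 64, the largest power of two <= 100
//   FindNumThreads(500);  // -> 256, clamped by max_nthrs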

/**
 * @brief Compute the number of bits needed to represent ids in [0, range),
 *        i.e., the bit width of range - 1.
 */
template <typename T>
int _NumberOfBits(const T& range) {
  if (range <= 1) {
    // ranges of 0 or 1 require no bits to store
    return 0;
  }

  int bits = 1;
  const auto urange = static_cast<std::make_unsigned_t<T>>(range);
  while (bits < static_cast<int>(sizeof(T) * 8) && (1ull << bits) < urange) {
    ++bits;
  }

  if (bits < static_cast<int>(sizeof(T) * 8)) {
    CHECK_EQ((range - 1) >> bits, 0);
  }
  CHECK_NE((range - 1) >> (bits - 1), 0);

  return bits;
}
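// Illustrative examples:
//   _NumberOfBits(5);  // -> 3 (ids 0..4 need three bits)
//   _NumberOfBits(1);  // -> 0 (a range of one value needs no bits)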

/**
 * @brief Find a number of blocks that is no greater than nblks and
 * max_nblks on the given axis ('x', 'y' or 'z').
 */
template <char axis>
inline int FindNumBlocks(int nblks, int max_nblks = -1) {
  int default_max_nblks = -1;
  switch (axis) {
    case 'x':
      default_max_nblks = CUDA_MAX_NUM_BLOCKS_X;
      break;
    case 'y':
      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Y;
      break;
    case 'z':
      default_max_nblks = CUDA_MAX_NUM_BLOCKS_Z;
      break;
    default:
      LOG(FATAL) << "Axis " << axis << " not recognized";
      break;
  }
  if (max_nblks == -1) max_nblks = default_max_nblks;
  CHECK_NE(nblks, 0);
  if (nblks < max_nblks) return nblks;
  return max_nblks;
}
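// Illustrative usage, pairing with FindNumThreads for a 1-D launch:
//   int nt = FindNumThreads(length);
//   int nb = FindNumBlocks<'x'>((length + nt - 1) / nt);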

/**
 * @brief Load a value through the read-only data cache (__ldg) on
 *        architectures that support it (SM 3.5+); plain load otherwise.
 */
template <typename T>
__device__ __forceinline__ T _ldg(T* addr) {
#if __CUDA_ARCH__ >= 350
  return __ldg(addr);
#else
  return *addr;
#endif
}

/**
 * @brief Return true if the given bool flag array is all true.
 *
 * The input bool array is in int8_t type so it is aligned with byte address.
 *
 * @param flags The bool array.
 * @param length The length.
 * @param ctx Device context.
 * @return True if all the flags are true.
 */
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx);

/**
 * @brief CUDA kernel that fills the vector starting at ptr, of size length,
 *        with val.
 * @note internal use only.
 */
template <typename DType>
__global__ void _FillKernel(DType* ptr, size_t length, DType val) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    ptr[tx] = val;
    tx += stride_x;
  }
}

/** @brief Fill the vector starting at ptr, of size length, with val. */
template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  int nt = FindNumThreads(length);
  int nb =
      (length + nt - 1) / nt;  // on x-axis, no need to worry about upperbound.
  CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);
}
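// Illustrative call, where dev_ptr is a hypothetical device allocation of
// n floats:
//   _Fill(static_cast<float*>(dev_ptr), n, 0.0f);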

/**
 * @brief Search adjacency list linearly for each (row, col) pair and
 * write the data under the matched position in the indices array to the
 * output.
 *
 * If there is no match, the value in \c filler is written.
 * If there are multiple matches, only the first match is written.
 * If the given data array is null, write the matched position to the output.
 */
template <typename IdType, typename DType>
__global__ void _LinearSearchKernel(
    const IdType* indptr, const IdType* indices, const IdType* data,
    const IdType* row, const IdType* col, int64_t row_stride,
    int64_t col_stride, int64_t length, const DType* weights, DType filler,
    DType* out) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    int rpos = tx * row_stride, cpos = tx * col_stride;
    IdType v = -1;
    const IdType r = row[rpos], c = col[cpos];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        v = data ? data[i] : i;
        break;
      }
    }
    if (v == -1) {
      out[tx] = filler;
    } else {
      // The casts here are to be able to handle DType being __half.
      // GCC treats int64_t as a distinct type from long long, so
      // without the explicit cast to long long, it errors out saying
      // that the implicit cast results in an ambiguous choice of
      // constructor for __half.
      // The using statement is to avoid a linter error about using
      // long or long long.
      using LongLong = long long;  // NOLINT
      out[tx] = weights ? weights[v] : DType(LongLong(v));
    }
    tx += stride_x;
  }
}

#if BF16_ENABLED
/**
 * @brief Specialization for bf16 because conversion from long long to bfloat16
 * doesn't exist before SM80.
 */
template <typename IdType>
__global__ void _LinearSearchKernel(
    const IdType* indptr, const IdType* indices, const IdType* data,
    const IdType* row, const IdType* col, int64_t row_stride,
    int64_t col_stride, int64_t length, const __nv_bfloat16* weights,
    __nv_bfloat16 filler, __nv_bfloat16* out) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    int rpos = tx * row_stride, cpos = tx * col_stride;
    IdType v = -1;
    const IdType r = row[rpos], c = col[cpos];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        v = data ? data[i] : i;
        break;
      }
    }
    if (v == -1) {
      out[tx] = filler;
    } else {
      // If the result is saved in bf16, it should be fine to convert it to
      // float first
      out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));
    }
    tx += stride_x;
  }
}
#endif  // BF16_ENABLED

/**
 * @brief Copy a single scalar from GPU memory to the host and return it.
 */
template <typename DType>
inline DType GetCUDAScalar(
    runtime::DeviceAPI* device_api, DGLContext ctx, const DType* cuda_ptr) {
  DType result;
  device_api->CopyDataFromTo(
      cuda_ptr, 0, &result, 0, sizeof(result), ctx, DGLContext{kDGLCPU, 0},
      DGLDataTypeTraits<DType>::dtype);
  return result;
}

/**
 * @brief Given a sorted array and a value this function returns the index
 * of the first element which compares greater than value.
 *
 * This function assumes 0-based indexing.
 * @param A: ascending sorted array
 * @param n: size of A
 * @param x: value to search in A
 * @return index, i, of the first element s.t. A[i] > x. If x >= A[n-1],
 * returns n; if x < A[0], returns 0.
 */
template <typename IdType>
__device__ IdType _UpperBound(const IdType* A, int64_t n, IdType x) {
  IdType l = 0, r = n, m = 0;
  while (l < r) {
    m = l + (r - l) / 2;
    if (x >= A[m]) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l;
}
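// Illustrative examples: for IdType A[] = {1, 3, 3, 7} and n = 4,
//   _UpperBound(A, 4, 3);  // -> 3, the first index with A[i] > 3
//   _UpperBound(A, 4, 7);  // -> 4 (== n), since no element exceeds 7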

/**
 * @brief Given a sorted array and a value this function returns the index
 * of the element that is equal to x. If no such element exists, returns n.
 *
 * This function assumes 0-based indexing.
 * @param A: ascending sorted array
 * @param n: size of A
 * @param x: value to search in A
 * @return index, i, s.t. A[i] == x. If no such index exists, returns n.
 */
template <typename IdType>
__device__ IdType _BinarySearch(const IdType* A, int64_t n, IdType x) {
  IdType l = 0, r = n - 1, m = 0;
  while (l <= r) {
    m = l + (r - l) / 2;
    if (A[m] == x) {
      return m;
    }
    if (A[m] < x) {
      l = m + 1;
    } else {
      r = m - 1;
    }
  }
  return n;  // not found
}
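// Illustrative examples: for IdType A[] = {1, 3, 5} and n = 3,
//   _BinarySearch(A, 3, 3);  // -> 1, the index of the match
//   _BinarySearch(A, 3, 4);  // -> 3 (== n), i.e. not found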

/**
 * @brief Select the elements of input whose corresponding mask entry is
 *        nonzero into output via cub::DeviceSelect::Flagged; the number of
 *        selected elements is written to rst.
 */
template <typename DType, typename BoolType>
void MaskSelect(
    runtime::DeviceAPI* device, const DGLContext& ctx, const DType* input,
    const BoolType* mask, DType* output, int64_t n, int64_t* rst,
    cudaStream_t stream) {
  size_t workspace_size = 0;
  // First call with a null workspace pointer only computes the required
  // temporary storage size.
  CUDA_CALL(cub::DeviceSelect::Flagged(
      nullptr, workspace_size, input, mask, output, rst, n, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  // Second call performs the actual selection.
  CUDA_CALL(cub::DeviceSelect::Flagged(
      workspace, workspace_size, input, mask, output, rst, n, stream));
  device->FreeWorkspace(ctx, workspace);
}

/**
 * @brief Get a pointer to an NDArray's data that is usable on the device;
 *        for pinned host memory, the matching device pointer is obtained
 *        via cudaHostGetDevicePointer().
 */
inline void* GetDevicePointer(runtime::NDArray array) {
  void* ptr = array->data;
  if (array.IsPinned()) {
    CUDA_CALL(cudaHostGetDevicePointer(&ptr, ptr, 0));
  }
  return ptr;
}

}  // namespace cuda
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_UTILS_H_