/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmat_op_impl_csr.cu
 * @brief CSR operator CUDA implementation
 */
#include <dgl/array.h>
7
8
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
9

10
#include <numeric>
11
12
13
#include <unordered_set>
#include <vector>

14
#include "../../runtime/cuda/cuda_common.h"
15
16
#include "./atomic.cuh"
#include "./dgl_cub.cuh"
17
#include "./utils.h"
18
19
20
21

namespace dgl {

using runtime::NDArray;
22
using namespace cuda;
23
24
25
26
27
28

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

/**
 * @brief Check whether the single coordinate (row, col) has a stored entry.
 *
 * The query is packed into one-element ID arrays on the matrix's device and
 * answered by a single-thread launch of the linear-search kernel; the result
 * is copied back to the CPU before it is inspected.
 */
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  constexpr int kIdBits = sizeof(IdType) * 8;
  const auto& dev_ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  IdArray row_arr =
      aten::VecToIdArray<int64_t>({row}, kIdBits, dev_ctx).CopyTo(dev_ctx);
  IdArray col_arr =
      aten::VecToIdArray<int64_t>({col}, kIdBits, dev_ctx).CopyTo(dev_ctx);
  IdArray search_out = aten::NewIdArray(1, dev_ctx, kIdBits);
  // No data array is passed to the search.
  const IdType* no_data = nullptr;

  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, 1, 1, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), no_data, row_arr.Ptr<IdType>(),
      col_arr.Ptr<IdType>(), 1, 1, 1, static_cast<IdType*>(nullptr),
      static_cast<IdType>(-1), search_out.Ptr<IdType>());

  // -1 is the filler passed above; any other value means a match was found.
  IdArray host_out = search_out.CopyTo(DGLContext{kDGLCPU, 0});
  return host_out.Ptr<IdType>()[0] != -1;
}

template bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

/**
 * @brief Batched membership test over (row[i], col[i]) query pairs.
 *
 * When exactly one of `row` / `col` has length 1, that side is broadcast
 * against the other through a zero stride. The output has
 * max(len(row), len(col)) elements and is the per-query search result
 * compared against the filler value -1.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto n_row_ids = row->shape[0];
  const auto n_col_ids = col->shape[0];
  const auto n_queries = std::max(n_row_ids, n_col_ids);
  NDArray search_result = NDArray::Empty({n_queries}, row->dtype, row->ctx);
  if (n_queries == 0) {
    return search_result;
  }

  const bool broadcast_row = n_row_ids == 1 && n_col_ids != 1;
  const bool broadcast_col = n_col_ids == 1 && n_row_ids != 1;
  const int64_t row_stride = broadcast_row ? 0 : 1;
  const int64_t col_stride = broadcast_col ? 0 : 1;

  const IdType* indptr_ptr =
      static_cast<IdType*>(GetDevicePointer(csr.indptr));
  const IdType* indices_ptr =
      static_cast<IdType*>(GetDevicePointer(csr.indices));
  // No data array is passed to the search.
  const IdType* no_data = nullptr;

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int threads = dgl::cuda::FindNumThreads(n_queries);
  const int blocks = (n_queries + threads - 1) / threads;
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, blocks, threads, 0, stream, indptr_ptr,
      indices_ptr, no_data, row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride,
      col_stride, n_queries, static_cast<IdType*>(nullptr),
      static_cast<IdType>(-1), search_result.Ptr<IdType>());
  // Presumably misses keep the filler -1; everything else is a hit.
  return search_result != -1;
}

template NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);

///////////////////////////// CSRHasDuplicate /////////////////////////////

/**
 * @brief Check whether each row does not have any duplicate entries.
 *
 * Assumes the CSR is sorted, so duplicate column indices within a row are
 * adjacent. Writes flags[r] = 1 if row r has no duplicates, 0 otherwise.
 *
 * Rows are visited with a grid-stride loop, so any 1-D launch covers all
 * num_rows rows. The row index is 64-bit so that matrices with more than
 * 2^31 rows do not overflow the loop counter.
 */
template <typename IdType>
__global__ void _SegmentHasNoDuplicate(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int8_t* flags) {
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < num_rows; tx += stride_x) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] != indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
  }
}

/**
 * @brief Return true if any row of `csr` stores the same column index twice.
 *
 * Unsorted input is sorted first (CSRSort) so duplicates become adjacent.
 * _SegmentHasNoDuplicate then writes one flag per row into a temporary
 * workspace, and AllTrue reduces those flags.
 */
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  // A matrix without rows has no duplicates. Returning here also avoids a
  // zero-byte workspace and a kernel launch with zero blocks (nb == 0), which
  // CUDA rejects as an invalid launch configuration.
  if (csr.num_rows == 0) return false;
  if (!csr.sorted) csr = CSRSort(csr);
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes. It wastes a little bit memory
  // but should be fine.
  int8_t* flags =
      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentHasNoDuplicate, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), csr.num_rows, flags);
  const bool no_duplicates = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return !no_duplicates;
}

template bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

/**
 * @brief Number of stored entries in one row: indptr[row + 1] - indptr[row].
 *
 * Each IndexSelect brings one indptr element back as a host-side scalar.
 */
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  const IdType row_begin = aten::IndexSelect<IdType>(csr.indptr, row);
  const IdType row_end = aten::IndexSelect<IdType>(csr.indptr, row + 1);
  return static_cast<int64_t>(row_end - row_begin);
}

template int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

/**
 * @brief out[i] = indptr[vid[i] + 1] - indptr[vid[i]] for i in [0, length).
 *
 * Uses a grid-stride loop with a 64-bit index, so any 1-D launch covers all
 * `length` queries and lengths beyond 2^31 do not overflow the counter.
 */
template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
    const IdType* vid, const IdType* indptr, IdType* out, int64_t length) {
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    const IdType vv = vid[tx];
    out[tx] = indptr[vv + 1] - indptr[vv];
  }
}

/**
 * @brief Per-row nonzero counts for a batch of row IDs.
 *
 * Returns an array shaped like `rows` (same dtype and context) whose i-th
 * element is the number of stored entries in row rows[i].
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  const auto len = rows->shape[0];
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  // Nothing to compute for an empty query. Returning here also avoids a
  // launch with nb == 0 blocks, which is not a valid CUDA configuration.
  if (len == 0) return rst;

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const IdType* vid_data = rows.Ptr<IdType>();
  const IdType* indptr_data =
      static_cast<IdType*>(GetDevicePointer(csr.indptr));
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _CSRGetRowNNZKernel, nb, nt, 0, stream, vid_data, indptr_data, rst_data,
      len);
  return rst;
}

template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);
////////////////////////// CSRGetRowColumnIndices //////////////////////////////
/**
 * @brief Zero-copy view of the column indices stored in one row.
 *
 * The view covers indices[indptr[row], indptr[row + 1]); CreateView is given
 * the start as a byte offset.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  const IdType row_start = aten::IndexSelect<IdType>(csr.indptr, row);
  const IdType row_end = aten::IndexSelect<IdType>(csr.indptr, row + 1);
  const int64_t row_nnz = row_end - row_start;
  const int64_t byte_offset = row_start * sizeof(IdType);
  return csr.indices.CreateView({row_nnz}, csr.indices->dtype, byte_offset);
}

template NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

/**
 * @brief Data values of the entries stored in one row.
 *
 * If `csr` has an explicit data array, returns a zero-copy view of
 * data[indptr[row], indptr[row + 1]). Otherwise the data of each entry is its
 * position in `indices`, so the row's data is the element range
 * [indptr[row], indptr[row] + nnz(row)), matching what CSRSliceRows produces.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  // Element position (not byte offset) of the row's first entry.
  const int64_t start = aten::IndexSelect<IdType>(csr.indptr, row);
  if (aten::CSRHasData(csr)) {
    // CreateView takes its offset in bytes.
    return csr.data.CreateView({len}, csr.data->dtype, start * sizeof(IdType));
  }
  // Range takes element positions; passing the byte offset here (as before)
  // returned ids scaled by sizeof(IdType).
  return aten::Range(
      start, start + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}

template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRSliceRows /////////////////////////////

/**
 * @brief Contiguous row slice [start, end) of a CSR matrix.
 *
 * indptr is copied and rebased to start at zero. indices (and data, when
 * present) are zero-copy views into the original arrays. Without a data
 * array, the slice's data are the original entry positions [first, last).
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  // Element positions delimiting rows [start, end) in indices/data.
  const IdType first = aten::IndexSelect<IdType>(csr.indptr, start);
  const IdType last = aten::IndexSelect<IdType>(csr.indptr, end);
  const IdType nnz = last - first;
  const int64_t byte_offset = first * sizeof(IdType);

  IdArray sliced_indptr =
      aten::IndexSelect(csr.indptr, start, end + 1) - first;
  IdArray sliced_indices =
      csr.indices.CreateView({nnz}, csr.indices->dtype, byte_offset);
  IdArray sliced_data =
      CSRHasData(csr)
          ? csr.data.CreateView({nnz}, csr.data->dtype, byte_offset)
          : aten::Range(first, last, csr.indptr->dtype.bits, csr.indptr->ctx);

  return CSRMatrix(
      end - start, csr.num_cols, sliced_indptr, sliced_indices, sliced_data,
      csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

/**
 * @brief Copy data segment to output buffers
 *
 * For the i^th row r = row[i], copy the data from indptr[r] ~ indptr[r+1]
 * to the out_data from out_indptr[i] ~ out_indptr[i+1]
 *
 * If the provided `data` array is nullptr, write the read index to the
 * out_data.
 *
 * One loop iteration handles one output element tx in [0, length); the
 * output row owning tx is located by searching out_indptr (n_row rows).
 */
template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel(
    const IdType* indptr, const DType* data, const IdType* row, int64_t length,
    int64_t n_row, const IdType* out_indptr, DType* out_data) {
  const int stride_x = gridDim.x * blockDim.x;
  for (IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    // Presumably _UpperBound returns the first position whose value exceeds
    // tx, so one before it is the output row containing element tx.
    const IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
    const IdType src = indptr[row[rpos]] + (tx - out_indptr[rpos]);
    if (data) {
      out_data[tx] = data[src];
    } else {
      out_data[tx] = src;
    }
  }
}

/**
 * @brief Gather an arbitrary list of rows into a new CSR matrix.
 *
 * The new indptr is the prefix sum of the selected rows' lengths with a
 * leading zero, so ret_indptr[len] is the total nnz. indices and data are
 * filled by _SegmentCopyKernel. When `csr` has no data array, the copied data
 * are the original entry positions.
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int64_t len = rows->shape[0];
  IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
  const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);

  IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
  IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);

  // If every selected row is empty there is nothing to copy, and nb would be
  // 0, which is not a valid CUDA launch configuration.
  if (nnz > 0) {
    const int nt = 256;  // for better GPU usage of small invocations
    const int nb = (nnz + nt - 1) / nt;

    const IdType* indptr_data =
        static_cast<IdType*>(GetDevicePointer(csr.indptr));
    const IdType* indices_data =
        static_cast<IdType*>(GetDevicePointer(csr.indices));
    const IdType* data_data =
        CSRHasData(csr) ? static_cast<IdType*>(GetDevicePointer(csr.data))
                        : nullptr;

    // Copy indices.
    CUDA_KERNEL_CALL(
        _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, indices_data,
        rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
        ret_indices.Ptr<IdType>());
    // Copy data (original entry positions when data_data is nullptr).
    CUDA_KERNEL_CALL(
        _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, data_data,
        rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
        ret_data.Ptr<IdType>());
  }
  return CSRMatrix(
      len, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

/**
 * @brief Generate a 0-1 mask for each index that hits the provided (row, col)
 *        index.
 *
 * Query tx reads row[tx * row_stride] and col[tx * col_stride], so a stride
 * of 0 broadcasts a length-1 array. Only hits are written (mask[i] = 1); the
 * caller must zero-initialize `mask`. A grid-stride loop with a 64-bit index
 * covers all `length` queries.
 *
 * Examples:
 * Given a CSR matrix (with duplicate entries) as follows, where each value is
 * the number of stored entries at that coordinate:
 * [[0, 1, 2, 0, 0],
 *  [1, 0, 0, 0, 0],
 *  [0, 0, 1, 1, 0],
 *  [0, 0, 0, 0, 0]]
 * i.e. stored (row, col) entries (0,1), (0,2), (0,2), (1,0), (2,2), (2,3).
 * Given rows: [0, 1], cols: [2, 0]
 * The result mask is: [0, 1, 1, 1, 0, 0]
 */
template <typename IdType>
__global__ void _SegmentMaskKernel(
    const IdType* indptr, const IdType* indices, const IdType* row,
    const IdType* col, int64_t row_stride, int64_t col_stride, int64_t length,
    IdType* mask) {
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    const IdType r = row[tx * row_stride];
    const IdType c = col[tx * col_stride];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        mask[i] = 1;
      }
    }
  }
}

/**
 * @brief For each needle, find the index of the last hay element that is not
 * greater than the needle.
 *
 * `hay` must be sorted in non-decreasing order and hay_size must be >= 1.
 * pos[i] is the largest j with hay[j] <= needles[i], or -1 if hay[0] already
 * exceeds the needle. When hay is a CSR indptr and the needles are entry
 * positions, pos[i] is the row r with indptr[r] <= needle < indptr[r + 1].
 * Empty rows are skipped because the last equal indptr value is chosen.
 * A needle at or beyond the last hay element maps to hay_size - 1.
 */
template <typename IdType>
__global__ void _SortedSearchKernel(
    const IdType* hay, int64_t hay_size, const IdType* needles,
    int64_t num_needles, IdType* pos) {
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < num_needles; tx += stride_x) {
    const IdType ele = needles[tx];
    // Upper-bound binary search: afterwards hi is the first index with
    // hay[hi] > ele, or hay_size - 1 if no such element exists.
    IdType lo = 0, hi = hay_size - 1;
    while (lo < hi) {
      const IdType mid = lo + ((hi - lo) >> 1);  // overflow-safe midpoint
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    // hay[hi] > ele: the answer is the element just before it.
    // hay[hi] <= ele: nothing exceeds ele, so hi itself is the answer.
    pos[tx] = (hay[hi] <= ele) ? hi : hi - 1;
  }
}

/**
 * @brief Find all stored entries matching the (row, col) query pairs.
 *
 * A length-1 `row` or `col` array is broadcast against the other. Returns
 * {rows, cols, data} of every matching stored entry in storage order. When
 * `csr` has no data array, the returned data are the entry positions.
 */
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto len = std::max(rowlen, collen);
  if (len == 0) return {NullArray(), NullArray(), NullArray()};

  const auto& ctx = row->ctx;
  const auto nbits = row->dtype.bits;
  const int64_t nnz = csr.indices->shape[0];
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const IdType* indptr_data =
      static_cast<IdType*>(GetDevicePointer(csr.indptr));
  const IdType* indices_data =
      static_cast<IdType*>(GetDevicePointer(csr.indices));

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, nnz, nbits, ctx);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentMaskKernel, nb, nt, 0, stream, indptr_data, indices_data,
      row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride, col_stride, len,
      mask.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  const int64_t num_hits = idx->shape[0];
  if (num_hits == 0)
    // No data. Return three empty arrays.
    return {idx, idx, idx};

  // Search for the row of each hit. All num_rows + 1 indptr entries are
  // searched, so the last one (the total nnz) is larger than every hit
  // position. This satisfies the search kernel's precondition and maps hits
  // in the last row correctly.
  IdArray ret_row = NewIdArray(num_hits, ctx, nbits);
  const int nt2 = dgl::cuda::FindNumThreads(num_hits);
  const int nb2 = (num_hits + nt2 - 1) / nt2;
  CUDA_KERNEL_CALL(
      _SortedSearchKernel, nb2, nt2, 0, stream, indptr_data, csr.num_rows + 1,
      idx.Ptr<IdType>(), num_hits, ret_row.Ptr<IdType>());

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
  return {ret_row, ret_col, ret_data};
}

template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRSliceMatrix /////////////////////////////

/**
 * @brief Return the smallest power of two strictly greater than numel.
 *
 * Examples: 1 -> 2, 3 -> 4, 4 -> 8, 5 -> 8. Returns 1 for numel <= 0.
 * Integer shifts avoid floating-point rounding in log2 for large inputs and
 * the undefined behavior of shifting the int literal 1 by 31 or more bits.
 */
int64_t _UpPower(int64_t numel) {
  if (numel <= 0) return 1;
  uint64_t ret = 1;
  while (ret <= static_cast<uint64_t>(numel)) {
    ret <<= 1;
  }
  return static_cast<int64_t>(ret);
}

/**
 * @brief Thomas Wang's 32 bit Mix Function.
 * Source link: https://gist.github.com/badboy/6267743
 *
 * Scrambles the bits of a 32-bit key. All arithmetic is on uint32_t and so
 * wraps modulo 2^32.
 */
__device__ inline uint32_t _Hash32Shift(uint32_t key) {
  uint32_t h = ~key + (key << 15);
  h ^= h >> 12;
  h += h << 2;
  h ^= h >> 4;
  h *= 2057;
  h ^= h >> 16;
  return h;
}

/**
 * @brief Thomas Wang's 64 bit Mix Function.
 * Source link: https://gist.github.com/badboy/6267743
 *
 * Scrambles the bits of a 64-bit key. All arithmetic is on uint64_t and so
 * wraps modulo 2^64.
 */
__device__ inline uint64_t _Hash64Shift(uint64_t key) {
  uint64_t h = (~key) + (key << 21);
  h ^= h >> 24;
  h += (h << 3) + (h << 8);
  h ^= h >> 14;
  h += (h << 2) + (h << 4);
  h ^= h >> 28;
  h += h << 31;
  return h;
}

/**
 * @brief A hashmap designed for CSRSliceMatrix, similar in function to set. For
 * performance, it can only be created and called in the cuda kernel.
 *
 * Open-addressing set over a caller-owned buffer of `numel` slots. Every slot
 * must hold kEmptyKey_ (-1) before the first Insert. `numel` must be a power
 * of two, because slot indices are reduced with `& (capacity_ - 1)`.
 *
 * Collisions are resolved with triangular probing: offsets 1, 3, 6, 10, ...
 * from the home slot. On a power-of-two table this sequence visits every slot
 * exactly once in capacity_ steps. Insert therefore always finds a free slot
 * while the table is not full, and Query always terminates.
 *
 * @tparam IdType Key type; the value -1 is reserved as the empty marker.
 */
template <typename IdType>
struct NodeQueryHashmap {
  __device__ inline NodeQueryHashmap(IdType* Kptr, size_t numel)
      : kptr_(Kptr), capacity_(numel) {}

  /**
   * @brief Insert a key. It must be called by cuda threads.
   *
   * Slots are claimed with an atomic compare-and-swap, so concurrent inserts
   * from many threads are safe. The table must have a free slot for every new
   * key.
   *
   * @param key The key to be inserted.
   */
  __device__ inline void Insert(IdType key) {
    uint32_t delta = 1;
    uint32_t pos = Hash(key);
    IdType prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);
    while (prev != key && prev != kEmptyKey_) {
      pos = NextProbe(pos, delta);
      delta += 1;
      prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);
    }
  }

  /**
   * @brief Check whether a key exists within the hashtable. It must be called
   * by cuda threads.
   *
   * Follows the same probe sequence as Insert. Stops at the key, at the first
   * empty slot, or after examining all capacity_ slots.
   *
   * @param key The key to check for.
   * @return True if the key exists in the hashtable.
   */
  __device__ inline bool Query(IdType key) {
    uint32_t pos = Hash(key);
    for (uint32_t delta = 1; delta <= capacity_; ++delta) {
      const IdType slot = kptr_[pos];
      if (slot == key) return true;
      if (slot == kEmptyKey_) return false;
      pos = NextProbe(pos, delta);
    }
    return false;
  }

  /** @brief Advance a probe position by `delta` slots, wrapping around. */
  __device__ inline uint32_t NextProbe(uint32_t pos, uint32_t delta) const {
    return (pos + delta) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(int32_t key) {
    return _Hash32Shift(key) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(uint32_t key) {
    return _Hash32Shift(key) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(int64_t key) {
    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(uint64_t key) {
    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);
  }

  IdType kEmptyKey_{-1};
  IdType* kptr_;
  uint32_t capacity_{0};
};

/**
 * @brief Generate a 0-1 mask for each index whose column is in the provided
 * hashmap. It also counts the number of masked values per row.
 *
 * Launch layout (checked by the asserts): blockDim = (WARP_SIZE, BLOCK_WARPS).
 * Block b handles rows [b * TILE_SIZE, min((b + 1) * TILE_SIZE, num_rows)).
 * Warp w of the block handles rows b * TILE_SIZE + w, + BLOCK_WARPS, ...,
 * and its lanes stride over the row's entries. Only hits are written to
 * `mask`, so the caller must zero it. Lane 0 writes count[r] for every row r
 * the grid covers.
 *
 * @tparam IdType The ID type used for matrices.
 * @tparam WARP_SIZE The number of cuda threads in a cuda warp.
 * @tparam BLOCK_WARPS The number of warps in a cuda block.
 * @tparam TILE_SIZE The number of rows covered by each threadblock.
 */
template <typename IdType, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void _SegmentMaskColKernel(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    IdType* hashmap_buffer, int64_t buffer_size, IdType* mask, IdType* count) {
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  int warp_id = threadIdx.y;
  int laneid = threadIdx.x;
  IdType out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  IdType last_row =
      min(static_cast<IdType>((blockIdx.x + 1) * TILE_SIZE),
          static_cast<IdType>(num_rows));

  NodeQueryHashmap<IdType> hashmap(hashmap_buffer, buffer_size);
  typedef cub::WarpReduce<IdType> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[BLOCK_WARPS];

  while (out_row < last_row) {
    IdType local_count = 0;
    IdType in_row_start = indptr[out_row];
    IdType in_row_end = indptr[out_row + 1];
    // IdType rather than int, so entry positions beyond 2^31 in 64-bit
    // matrices do not overflow the loop counter.
    for (IdType idx = in_row_start + laneid; idx < in_row_end;
         idx += WARP_SIZE) {
      bool is_in = hashmap.Query(indices[idx]);
      if (is_in) {
        local_count += 1;
        mask[idx] = 1;
      }
    }
    // out_row is uniform across the warp, so all lanes reach the reduction.
    IdType reduce_count = WarpReduce(temp_storage[warp_id]).Sum(local_count);
    if (laneid == 0) {
      count[out_row] = reduce_count;
    }
    out_row += BLOCK_WARPS;
  }
}

/**
 * @brief Submatrix of `csr` at the given rows and columns.
 *
 * Rows are gathered first with CSRSliceRows. The requested columns are then
 * inserted into a NodeQueryHashmap, entries whose column is present are
 * masked and counted per row, and the surviving columns are relabelled to
 * their positions in `cols`.
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = rows->ctx;
  const auto& dtype = rows->dtype;
  const auto nbits = dtype.bits;
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];

  if (new_nrows == 0 || new_ncols == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // First slice rows
  csr = CSRSliceRows(csr, rows);

  if (csr.indices->shape[0] == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
  // A count for how many masked values per row.
  IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
  // Zero on `stream`, the same stream as the kernels below. A plain cudaMemset
  // is issued on the legacy default stream and is not ordered against work
  // on a non-blocking `stream`.
  CUDA_CALL(cudaMemsetAsync(
      count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows), stream));

  // Generate a NodeQueryHashmap buffer. The key of the hashmap is col.
  // For performance, the load factor of the hashmap is in (0.25, 0.5);
  // Because num_cols is usually less than 1 Million (on GPU), the
  // memory overhead is not significant (less than 31MB) at a low load factor.
  int64_t buffer_size = _UpPower(new_ncols) * 2;
  IdArray hashmap_buffer = Full(-1, buffer_size, nbits, ctx);

  using it = thrust::counting_iterator<int64_t>;
  runtime::CUDAWorkspaceAllocator allocator(ctx);
  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
  thrust::for_each(
      exec_policy, it(0), it(new_ncols),
      [key = cols.Ptr<IdType>(), buffer = hashmap_buffer.Ptr<IdType>(),
       buffer_size] __device__(int64_t i) {
        NodeQueryHashmap<IdType> hashmap(buffer, buffer_size);
        hashmap.Insert(key[i]);
      });

  const IdType* indptr_data =
      static_cast<IdType*>(GetDevicePointer(csr.indptr));
  const IdType* indices_data =
      static_cast<IdType*>(GetDevicePointer(csr.indices));

  // Execute SegmentMaskColKernel
  const int64_t num_rows = csr.num_rows;
  constexpr int WARP_SIZE = 32;
  // With a simple fine-tuning, TILE_SIZE=16 gives a good performance.
  constexpr int TILE_SIZE = 16;
  constexpr int BLOCK_WARPS = CUDA_MAX_NUM_THREADS / WARP_SIZE;
  IdType nb =
      dgl::cuda::FindNumBlocks<'x'>((num_rows + TILE_SIZE - 1) / TILE_SIZE);
  const dim3 nthrs(WARP_SIZE, BLOCK_WARPS);
  const dim3 nblks(nb);
  CUDA_KERNEL_CALL(
      (_SegmentMaskColKernel<IdType, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>), nblks,
      nthrs, 0, stream, indptr_data, indices_data, num_rows,
      hashmap_buffer.Ptr<IdType>(), buffer_size, mask.Ptr<IdType>(),
      count.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Indptr needs to be adjusted according to the new nnz per row.
  IdArray ret_indptr = CumSum(count, true);

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;

  // Relabel column: col_hash[cols[j]] = j. Only entries at positions in
  // `cols` are written, which covers every surviving column.
  IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
  Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
  ret_col = IndexSelect(col_hash, ret_col);

  return CSRMatrix(new_nrows, new_ncols, ret_indptr, ret_col, ret_data);
}

template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

}  // namespace impl
}  // namespace aten
}  // namespace dgl