"vscode:/vscode.git/clone" did not exist on "179712024ec757bdc7fa6f91d574a7592015e394"
spmat_op_impl_csr.cu 20.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/spmat_op_impl_csr.cu
 * \brief CSR operator CUDA (GPU) implementation
 */
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
#include <numeric>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
12
13
#include "./atomic.cuh"
#include "./dgl_cub.cuh"
14
15
16
17
18
19
20
21
22
23

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

/*!
 * \brief Test whether the single entry (row, col) is stored in the matrix.
 *
 * The query is staged into one-element ID arrays on the matrix's context, a
 * single-thread _LinearSearchKernel writes its result into a one-element
 * output (filler -1 is passed for "no match"), and that output is copied back
 * to the host for the comparison.
 */
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = csr.indptr->ctx;

  IdArray query_row = aten::VecToIdArray<int64_t>({row}, sizeof(IdType) * 8, ctx);
  IdArray query_col = aten::VecToIdArray<int64_t>({col}, sizeof(IdType) * 8, ctx);
  query_row = query_row.CopyTo(ctx);
  query_col = query_col.CopyTo(ctx);
  IdArray found = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);

  // No explicit data array: the kernel receives a null data pointer.
  const IdType* no_data = nullptr;
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
      1, 1, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), no_data,
      query_row.Ptr<IdType>(), query_col.Ptr<IdType>(),
      1, 1, 1,
      static_cast<IdType*>(nullptr), static_cast<IdType>(-1),
      found.Ptr<IdType>());

  // Bring the single result back to the host before dereferencing it.
  found = found.CopyTo(DGLContext{kDGLCPU, 0});
  return *found.Ptr<IdType>() != -1;
}

template bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

/*!
 * \brief Vectorized CSRIsNonZero: element i of the result tells whether the
 *        pair (row[i], col[i]) is stored in the matrix.
 *
 * A length-1 `row` or `col` is broadcast against the other array by giving it
 * a stride of 0. The search kernel is given -1 as the filler value, and the
 * returned mask is the element-wise comparison of its output with -1.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto num_row_ids = row->shape[0];
  const auto num_col_ids = col->shape[0];
  const auto out_len = std::max(num_row_ids, num_col_ids);
  NDArray positions = NDArray::Empty({out_len}, row->dtype, row->ctx);
  if (out_len == 0)
    return positions;

  // Broadcast whichever side holds a single element.
  const int64_t row_step = (num_row_ids == 1 && num_col_ids != 1) ? 0 : 1;
  const int64_t col_step = (num_col_ids == 1 && num_row_ids != 1) ? 0 : 1;

  // Pinned host arrays are accessed by the kernel through their device alias.
  const IdType* indptr_ptr = csr.indptr.Ptr<IdType>();
  const IdType* indices_ptr = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_ptr, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(
        &indices_ptr, csr.indices.Ptr<IdType>(), 0));
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int threads = dgl::cuda::FindNumThreads(out_len);
  const int blocks = (out_len + threads - 1) / threads;
  const IdType* no_data = nullptr;
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
      blocks, threads, 0, stream,
      indptr_ptr, indices_ptr, no_data,
      row.Ptr<IdType>(), col.Ptr<IdType>(),
      row_step, col_step, out_len,
      static_cast<IdType*>(nullptr), static_cast<IdType>(-1),
      positions.Ptr<IdType>());

  return positions != -1;
}

template NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);

///////////////////////////// CSRHasDuplicate /////////////////////////////

/*!
 * \brief Check whether each row does not have any duplicate entries.
 *
 * Assumes the CSR is sorted, so equal column indices within a row are
 * adjacent. flags[r] is set to 1 when row r has no two equal consecutive
 * column indices and to 0 otherwise.
 *
 * Each thread walks rows with a grid-stride loop, so any 1-D launch
 * configuration covers all `num_rows` rows. The loop counter is 64-bit so
 * that `tx += stride_x` cannot overflow for very large row counts.
 */
template <typename IdType>
__global__ void _SegmentHasNoDuplicate(
    const IdType* indptr, const IdType* indices,
    int64_t num_rows, int8_t* flags) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] != indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

/*!
 * \brief Return true if any row stores the same column index more than once.
 *
 * Unsorted input is sorted first (CSRSort) so that duplicates become
 * adjacent. A per-row "no duplicate" flag is then computed on the GPU and
 * reduced with dgl::cuda::AllTrue.
 */
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  if (!csr.sorted)
    csr = CSRSort(csr);
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);

  // One flag byte per row: slightly wasteful, but simple.
  int8_t* row_ok = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int threads = dgl::cuda::FindNumThreads(csr.num_rows);
  const int blocks = (csr.num_rows + threads - 1) / threads;
  CUDA_KERNEL_CALL(_SegmentHasNoDuplicate,
      blocks, threads, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, row_ok);
  const bool no_duplicates = dgl::cuda::AllTrue(row_ok, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, row_ok);
  return !no_duplicates;
}

template bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

/*!
 * \brief Number of stored entries in one row, i.e. indptr[row + 1] - indptr[row].
 */
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  const IdType row_begin = aten::IndexSelect<IdType>(csr.indptr, row);
  const IdType row_end = aten::IndexSelect<IdType>(csr.indptr, row + 1);
  return static_cast<int64_t>(row_end - row_begin);
}

template int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

/*!
 * \brief out[i] = indptr[vid[i] + 1] - indptr[vid[i]] for i in [0, length).
 *
 * Grid-stride loop, so any 1-D launch configuration covers all `length`
 * queries. The loop counter is 64-bit so that `tx += stride_x` cannot
 * overflow for very long query arrays.
 */
template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
    const IdType* vid,
    const IdType* indptr,
    IdType* out,
    int64_t length) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const IdType vv = vid[tx];
    out[tx] = indptr[vv + 1] - indptr[vv];
    tx += stride_x;
  }
}

/*!
 * \brief Stored-entry counts for a batch of rows.
 *
 * \param csr  The matrix; a pinned indptr is read through its device alias.
 * \param rows Row IDs to query; passed to the kernel as-is.
 * \return An array of rows->shape[0] counts with the dtype and context of `rows`.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  const auto num_queries = rows->shape[0];
  NDArray counts = NDArray::Empty({num_queries}, rows->dtype, rows->ctx);

  const IdType* indptr_ptr = csr.indptr.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_ptr, csr.indptr.Ptr<IdType>(), 0));
  }

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int threads = dgl::cuda::FindNumThreads(num_queries);
  const int blocks = (num_queries + threads - 1) / threads;
  CUDA_KERNEL_CALL(_CSRGetRowNNZKernel,
      blocks, threads, 0, stream,
      rows.Ptr<IdType>(), indptr_ptr,
      static_cast<IdType*>(counts->data), num_queries);
  return counts;
}

template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetRowColumnIndices /////////////////////////////

/*!
 * \brief Column indices of a single row, as a zero-copy view into csr.indices.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  const int64_t row_nnz = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const int64_t first = aten::IndexSelect<IdType>(csr.indptr, row);
  // CreateView takes its offset in bytes.
  return csr.indices.CreateView(
      {row_nnz}, csr.indices->dtype, first * sizeof(IdType));
}

template NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

/*!
 * \brief Data (edge IDs) of the entries stored in a single row.
 *
 * If the matrix has a data array, a zero-copy view of the row's slice is
 * returned. Otherwise the data of an entry is its position in `indices`, so
 * the result is the range [indptr[row], indptr[row] + nnz(row)). This is the
 * same convention CSRSliceRows uses.
 */
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  // Element position of the row's first entry in indices/data.
  const int64_t start = aten::IndexSelect<IdType>(csr.indptr, row);
  if (aten::CSRHasData(csr)) {
    // CreateView takes its offset in bytes.
    return csr.data.CreateView({len}, csr.data->dtype, start * sizeof(IdType));
  }
  // Implicit data: the positions themselves (element indices, not bytes).
  return aten::Range(start, start + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}

template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRSliceRows /////////////////////////////

/*!
 * \brief Slice the contiguous row range [start, end).
 *
 * The returned indptr is rebased to start at zero. indices and data are
 * zero-copy views. When the input has no data array, the slice's data is the
 * range of original positions [indptr[start], indptr[end]).
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const IdType first_pos = aten::IndexSelect<IdType>(csr.indptr, start);
  const IdType last_pos = aten::IndexSelect<IdType>(csr.indptr, end);
  const IdType slice_nnz = last_pos - first_pos;
  // CreateView takes its offset in bytes.
  const int64_t byte_offset = first_pos * sizeof(IdType);

  IdArray new_indptr = aten::IndexSelect(csr.indptr, start, end + 1) - first_pos;
  IdArray new_indices = csr.indices.CreateView(
      {slice_nnz}, csr.indices->dtype, byte_offset);
  IdArray new_data = CSRHasData(csr)
      ? csr.data.CreateView({slice_nnz}, csr.data->dtype, byte_offset)
      : aten::Range(first_pos, last_pos,
                    csr.indptr->dtype.bits, csr.indptr->ctx);

  return CSRMatrix(end - start, csr.num_cols,
                   new_indptr, new_indices, new_data,
                   csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

/*!
 * \brief Gather the segments of the selected rows into a compacted output.
 *
 * Output position t belongs to output row i with out_indptr[i] <= t <
 * out_indptr[i+1]; i is found by _UpperBound over the first n_row entries of
 * out_indptr. Position t receives element indptr[row[i]] + (t - out_indptr[i])
 * of `data`, or that source position itself when `data` is nullptr.
 * Threads visit output positions with a grid-stride loop.
 */
template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel(
    const IdType* indptr, const DType* data,
    const IdType* row, int64_t length, int64_t n_row,
    const IdType* out_indptr, DType* out_data) {
  const int stride_x = gridDim.x * blockDim.x;
  for (IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    const IdType out_row = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
    const IdType src = indptr[row[out_row]] + (tx - out_indptr[out_row]);
    out_data[tx] = data ? data[src] : src;
  }
}

/*!
 * \brief Slice an arbitrary list of rows (possibly non-contiguous or repeated).
 *
 * The new indptr is the cumulative sum (with a leading zero) of the selected
 * rows' lengths. indices and data are gathered with _SegmentCopyKernel. When
 * the input has no data array, the gathered data are the original positions.
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int64_t num_selected = rows->shape[0];
  IdArray new_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
  const int64_t new_nnz = aten::IndexSelect<IdType>(new_indptr, num_selected);

  // Fixed block size, for better GPU usage of small invocations.
  const int threads = 256;
  const int blocks = (new_nnz + threads - 1) / threads;

  // Pinned host arrays are accessed by the kernel through their device alias.
  const IdType* indptr_ptr = csr.indptr.Ptr<IdType>();
  const IdType* indices_ptr = csr.indices.Ptr<IdType>();
  const IdType* data_ptr = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
  if (csr.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_ptr, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(
        &indices_ptr, csr.indices.Ptr<IdType>(), 0));
    if (CSRHasData(csr)) {
      CUDA_CALL(cudaHostGetDevicePointer(
          &data_ptr, csr.data.Ptr<IdType>(), 0));
    }
  }

  // Gather column indices.
  IdArray new_indices = NDArray::Empty({new_nnz}, csr.indptr->dtype, rows->ctx);
  CUDA_KERNEL_CALL(_SegmentCopyKernel,
      blocks, threads, 0, stream,
      indptr_ptr, indices_ptr,
      rows.Ptr<IdType>(), new_nnz, num_selected,
      new_indptr.Ptr<IdType>(), new_indices.Ptr<IdType>());

  // Gather data (a null data pointer makes the kernel emit positions).
  IdArray new_data = NDArray::Empty({new_nnz}, csr.indptr->dtype, rows->ctx);
  CUDA_KERNEL_CALL(_SegmentCopyKernel,
      blocks, threads, 0, stream,
      indptr_ptr, data_ptr,
      rows.Ptr<IdType>(), new_nnz, num_selected,
      new_indptr.Ptr<IdType>(), new_data.Ptr<IdType>());

  return CSRMatrix(num_selected, csr.num_cols,
                   new_indptr, new_indices, new_data,
                   csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray);

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

/*!
 * \brief Generate a 0-1 mask for each index that hits the provided (row, col)
 *        index.
 *
 * For each i in [0, length), every stored entry of row row[i * row_stride]
 * whose column equals col[i * col_stride] gets mask[position] = 1. The kernel
 * only writes ones, so `mask` must be zero-initialized by the caller. A
 * stride of 0 broadcasts a single row or column ID.
 *
 * Examples:
 * Given a CSR matrix (with duplicate entries) as follows, where each number is
 * the count of stored entries at that coordinate:
 * [[0, 1, 2, 0, 0],
 *  [1, 0, 0, 0, 0],
 *  [0, 0, 1, 1, 0],
 *  [0, 0, 0, 0, 0]]
 * i.e. indptr = [0, 3, 4, 6, 6] and indices = [1, 2, 2, 0, 2, 3].
 * Given rows: [0, 1], cols: [2, 0] (pairs (0, 2) and (1, 0))
 * The result mask is: [0, 1, 1, 1, 0, 0]
 *
 * Threads iterate with a 64-bit grid-stride loop so the counters cannot
 * overflow for very long query arrays.
 */
template <typename IdType>
__global__ void _SegmentMaskKernel(
    const IdType* indptr, const IdType* indices,
    const IdType* row, const IdType* col,
    int64_t row_stride, int64_t col_stride,
    int64_t length, IdType* mask) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const int64_t rpos = tx * row_stride, cpos = tx * col_stride;
    const IdType r = row[rpos], c = col[cpos];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        mask[i] = 1;
      }
    }
    tx += stride_x;
  }
}

/*!
 * \brief For each needle, find the index of the last hay element that is less
 *        than or equal to it.
 *
 * The hay must be sorted in non-decreasing order. With a CSR indptr as the
 * hay, this maps an nnz position to the row containing it (empty rows, which
 * repeat an indptr value, are skipped). pos[i] is -1 when needles[i] < hay[0].
 * hay_size must be at least 1.
 *
 * Each needle is resolved with an upper-bound binary search. The 64-bit
 * grid-stride loop covers all needles for any 1-D launch configuration.
 */
template <typename IdType>
__global__ void _SortedSearchKernel(
    const IdType* hay, int64_t hay_size,
    const IdType* needles, int64_t num_needles,
    IdType* pos) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < num_needles) {
    const IdType ele = needles[tx];
    // Upper-bound search restricted to [0, hay_size - 1]. On exit `hi` is the
    // first index with hay[hi] > ele, or hay_size - 1 if every element <= ele.
    IdType lo = 0, hi = hay_size - 1;
    while (lo < hi) {
      IdType mid = (lo + hi) >> 1;
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    // If hay[hi] > ele, the answer is the element just before it. Otherwise
    // every element is <= ele and the answer is `hi` itself. This also
    // covers needles larger than the last hay element, which an equality
    // test against hay[hi] would place one slot too early.
    pos[tx] = (hay[hi] > ele) ? hi - 1 : hi;
    tx += stride_x;
  }
}

/*!
 * \brief Look up every stored entry matching the given (row, col) pairs.
 *
 * `row` and `col` are matched element-wise; a length-1 array is broadcast
 * against the other. Duplicate entries in the matrix are all returned.
 *
 * \return {rows, cols, data} of the matching stored entries. Three null arrays
 *         are returned when `row` and `col` are both empty, and three empty
 *         arrays when nothing matches.
 */
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto len = std::max(rowlen, collen);
  if (len == 0)
    return {NullArray(), NullArray(), NullArray()};

  const auto& ctx = row->ctx;
  const auto nbits = row->dtype.bits;
  const int64_t nnz = csr.indices->shape[0];
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Pinned host arrays are accessed by the kernels through their device alias.
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(
        &indices_data, csr.indices.Ptr<IdType>(), 0));
  }

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, nnz, nbits, ctx);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentMaskKernel,
      nb, nt, 0, stream,
      indptr_data, indices_data,
      row.Ptr<IdType>(), col.Ptr<IdType>(),
      row_stride, col_stride, len,
      mask.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    // No data. Return three empty arrays.
    return {idx, idx, idx};

  // Map each matched position back to its row. The hay is the FULL indptr
  // (num_rows + 1 entries), whose last element nnz is larger than every
  // position, so positions inside the last row resolve to num_rows - 1
  // instead of one row too early.
  const int64_t num_matched = idx->shape[0];
  IdArray ret_row = NewIdArray(num_matched, ctx, nbits);
  const int nt2 = dgl::cuda::FindNumThreads(num_matched);
  const int nb2 = (num_matched + nt2 - 1) / nt2;
  CUDA_KERNEL_CALL(_SortedSearchKernel,
      nb2, nt2, 0, stream,
      indptr_data, csr.num_rows + 1,
      idx.Ptr<IdType>(), num_matched,
      ret_row.Ptr<IdType>());

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
  return {ret_row, ret_col, ret_data};
}

template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRSliceMatrix /////////////////////////////

/*!
 * \brief Mark every stored entry whose column index appears in `col`, and count
 *        the marked entries of each row.
 *
 * `col` is searched with _BinarySearch and a result < col_len is treated as a
 * hit, so `col` is expected to be sorted (CSRSliceMatrix sorts it first). The
 * kernel only writes ones into `mask` and atomically increments `count`, so
 * both must be zero-initialized by the caller. Threads iterate over nnz
 * positions with a grid-stride loop.
 */
template <typename IdType>
__global__ void _SegmentMaskColKernel(
    const IdType* indptr, const IdType* indices, int64_t num_rows, int64_t num_nnz,
    const IdType* col, int64_t col_len,
    IdType* mask, IdType* count) {
  const int stride_x = gridDim.x * blockDim.x;
  for (IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < num_nnz; tx += stride_x) {
    const IdType column = indices[tx];
    const IdType hit = dgl::cuda::_BinarySearch(col, col_len, column);
    if (hit < col_len) {
      // Row that owns position tx: the last r with indptr[r] <= tx.
      const IdType owner_row = dgl::cuda::_UpperBound(indptr, num_rows, tx) - 1;
      mask[tx] = 1;
      cuda::AtomicAdd(count + owner_row, IdType(1));
    }
  }
}

/*!
 * \brief Extract the submatrix formed by the given rows and columns.
 *
 * Row i of the result is input row rows[i], and column j is input column
 * cols[j]. Rows are sliced first (CSRSliceRows). Entries whose column is not
 * in `cols` are then masked out, and the surviving column indices are
 * relabeled to their position in `cols`.
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = rows->ctx;
  const auto& dtype = rows->dtype;
  const auto nbits = dtype.bits;
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];

  if (new_nrows == 0 || new_ncols == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // First slice rows
  csr = CSRSliceRows(csr, rows);

  if (csr.indices->shape[0] == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
  // A count for how many masked values per row.
  IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
  // Zero `count` on the same stream as the kernel that accumulates into it.
  // A plain cudaMemset runs on the legacy default stream, which is not
  // ordered with non-blocking streams.
  CUDA_CALL(cudaMemsetAsync(
      count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows), stream));

  const int64_t nnz_csr = csr.indices->shape[0];
  const int nt = 256;

  // ``cols'' is usually sorted, but that is not guaranteed, and the mask
  // kernel binary-searches it. So always sort a copy; the input is untouched.
  auto device = runtime::DeviceAPI::Get(ctx);
  auto cols_size = cols->shape[0];

  IdArray sorted_array = NewIdArray(cols->shape[0], ctx, cols->dtype.bits);
  auto ptr_sorted_cols = sorted_array.Ptr<IdType>();
  auto ptr_cols = cols.Ptr<IdType>();
  size_t workspace_size = 0;
  // First call only queries the temporary storage size.
  CUDA_CALL(cub::DeviceRadixSort::SortKeys(
       nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
       0, sizeof(IdType)*8, stream));
  void *workspace = device->AllocWorkspace(ctx, workspace_size);
  CUDA_CALL(cub::DeviceRadixSort::SortKeys(
       workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0],
       0, sizeof(IdType)*8, stream));
  device->FreeWorkspace(ctx, workspace);

  // Pinned host arrays are accessed by the kernel through their device alias.
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(cudaHostGetDevicePointer(
        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(cudaHostGetDevicePointer(
        &indices_data, csr.indices.Ptr<IdType>(), 0));
  }

  // Execute SegmentMaskColKernel
  int nb = (nnz_csr + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentMaskColKernel,
      nb, nt, 0, stream,
      indptr_data, indices_data, csr.num_rows, nnz_csr,
      ptr_sorted_cols, cols_size,
      mask.Ptr<IdType>(), count.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    return CSRMatrix(new_nrows, new_ncols,
                     Full(0, new_nrows + 1, nbits, ctx),
                     NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Indptr needs to be adjusted according to the new nnz per row.
  IdArray ret_indptr = CumSum(count, true);

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;

  // Relabel columns: col_hash[cols[j]] = j, so old column IDs map to new ones.
  IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
  Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
  ret_col = IndexSelect(col_hash, ret_col);

  return CSRMatrix(new_nrows, new_ncols, ret_indptr,
                   ret_col, ret_data);
}

template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

}  // namespace impl
}  // namespace aten
}  // namespace dgl