"tests/vscode:/vscode.git/clone" did not exist on "762b4fd80f48488ef57b18f8c0143b0bb406dbd2"
spmat_op_impl_csr.cu 19.9 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/spmat_op_impl_csr.cu
 * @brief CSR operator GPU (CUDA) implementation.
 */
#include <dgl/array.h>

#include <numeric>
#include <unordered_set>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

///////////////////////////// CSRIsNonZero /////////////////////////////

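/**
 * @brief Return whether entry (row, col) of the matrix is non-zero.
 *
 * A single-thread linear-search kernel scans the row on the device and the
 * scalar result is copied back to the host.
 */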
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = csr.indptr->ctx;
  IdArray rows = aten::VecToIdArray<int64_t>({row}, sizeof(IdType) * 8, ctx);
  IdArray cols = aten::VecToIdArray<int64_t>({col}, sizeof(IdType) * 8, ctx);
  rows = rows.CopyTo(ctx);
  cols = cols.CopyTo(ctx);
  IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);
  const IdType* data = nullptr;
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, 1, 1, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), data, rows.Ptr<IdType>(), cols.Ptr<IdType>(),
      1, 1, 1, static_cast<IdType*>(nullptr), static_cast<IdType>(-1),
      out.Ptr<IdType>());
  out = out.CopyTo(DGLContext{kDGLCPU, 0});
  return *out.Ptr<IdType>() != -1;
}

template bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

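/**
 * @brief Batched version: for each pair (row[i], col[i]), return whether that
 * entry is non-zero.
 *
 * If one of `row` or `col` has length 1, it is broadcast against the other
 * (its stride is set to 0).
 */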
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  if (rstlen == 0) return rst;
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = dgl::cuda::FindNumThreads(rstlen);
  const int nb = (rstlen + nt - 1) / nt;
  const IdType* data = nullptr;
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
  }
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      dgl::cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data,
      indices_data, data, row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride,
      col_stride, rstlen, static_cast<IdType*>(nullptr),
      static_cast<IdType>(-1), rst.Ptr<IdType>());
  return rst != -1;
}

template NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);

///////////////////////////// CSRHasDuplicate /////////////////////////////

/**
 * @brief For each row, set flags[row] to 1 if the row has no duplicate
 * entries. Assumes the CSR indices are sorted within each row.
 */
template <typename IdType>
__global__ void _SegmentHasNoDuplicate(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int8_t* flags) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] != indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

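/**
 * @brief Return whether the matrix contains duplicate (row, col) entries.
 *
 * Sorts the CSR if necessary, flags duplicate-free rows with
 * _SegmentHasNoDuplicate, and reduces the per-row flags with AllTrue.
 */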
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
  if (!csr.sorted) csr = CSRSort(csr);
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes. It wastes a little memory
  // but should be fine.
  int8_t* flags =
      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentHasNoDuplicate, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), csr.num_rows, flags);
  bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return !ret;
}

template bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);

///////////////////////////// CSRGetRowNNZ /////////////////////////////

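/**
 * @brief Return the number of non-zero entries in row `row`, i.e.
 * indptr[row + 1] - indptr[row].
 *
 * Example: with indptr = [0, 2, 3, 5], row 1 has 3 - 2 = 1 non-zero entry.
 */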
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  const IdType cur = aten::IndexSelect<IdType>(csr.indptr, row);
  const IdType next = aten::IndexSelect<IdType>(csr.indptr, row + 1);
  return next - cur;
}

template int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

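/**
 * @brief For each queried row vid[i], write its non-zero count
 * indptr[vid[i] + 1] - indptr[vid[i]] to out[i].
 */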
template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
    const IdType* vid, const IdType* indptr, IdType* out, int64_t length) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    const IdType vv = vid[tx];
    out[tx] = indptr[vv + 1] - indptr[vv];
    tx += stride_x;
  }
}

template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto len = rows->shape[0];
  const IdType* vid_data = rows.Ptr<IdType>();
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
  }
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _CSRGetRowNNZKernel, nb, nt, 0, stream, vid_data, indptr_data, rst_data,
      len);
  return rst;
}

template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);

////////////////////////// CSRGetRowColumnIndices //////////////////////////////

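/**
 * @brief Return the column indices of row `row` as a zero-copy view into
 * csr.indices, offset by indptr[row] elements.
 */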
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const int64_t offset =
      aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
  return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}

template NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRGetRowData /////////////////////////////

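/**
 * @brief Return the data entries of row `row`.
 *
 * If the matrix carries an explicit data array, a view of it is returned;
 * otherwise the entry indices themselves are generated as a range.
 */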
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row);
  if (aten::CSRHasData(csr))
    return csr.data.CreateView(
        {len}, csr.data->dtype, offset * sizeof(IdType));
  else
    return aten::Range(
        offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}

template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);

///////////////////////////// CSRSliceRows /////////////////////////////

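/**
 * @brief Slice rows [start, end) of the matrix.
 *
 * Only the new indptr is materialized; the returned indices and data are
 * zero-copy views into the input arrays when possible.
 */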
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  const int64_t num_rows = end - start;
  const IdType st_pos = aten::IndexSelect<IdType>(csr.indptr, start);
  const IdType ed_pos = aten::IndexSelect<IdType>(csr.indptr, end);
  const IdType nnz = ed_pos - st_pos;
  IdArray ret_indptr = aten::IndexSelect(csr.indptr, start, end + 1) - st_pos;
  // indices and data can be view arrays
  IdArray ret_indices = csr.indices.CreateView(
      {nnz}, csr.indices->dtype, st_pos * sizeof(IdType));
  IdArray ret_data;
  if (CSRHasData(csr))
    ret_data =
        csr.data.CreateView({nnz}, csr.data->dtype, st_pos * sizeof(IdType));
  else
    ret_data =
        aten::Range(st_pos, ed_pos, csr.indptr->dtype.bits, csr.indptr->ctx);
  return CSRMatrix(
      num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);

/**
 * @brief Copy data segments to the output buffer.
 *
 * For the i^th row r = row[i], copy the data in indptr[r] ~ indptr[r+1]
 * to out_data at positions out_indptr[i] ~ out_indptr[i+1].
 *
 * If the provided `data` array is nullptr, write the read index to out_data
 * instead.
 */
template <typename IdType, typename DType>
__global__ void _SegmentCopyKernel(
    const IdType* indptr, const DType* data, const IdType* row, int64_t length,
    int64_t n_row, const IdType* out_indptr, DType* out_data) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
    IdType rofs = tx - out_indptr[rpos];
    const IdType u = row[rpos];
    out_data[tx] = data ? data[indptr[u] + rofs] : indptr[u] + rofs;
    tx += stride_x;
  }
}

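/**
 * @brief Slice an arbitrary set of rows given by `rows`.
 *
 * The new indptr is the cumulative sum of the selected rows' non-zero counts;
 * indices and data are then gathered by _SegmentCopyKernel with one thread
 * per output non-zero.
 */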
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int64_t len = rows->shape[0];
  IdArray ret_indptr = aten::CumSum(aten::CSRGetRowNNZ(csr, rows), true);
  const int64_t nnz = aten::IndexSelect<IdType>(ret_indptr, len);

  const int nt = 256;  // for better GPU usage of small invocations
  const int nb = (nnz + nt - 1) / nt;

  // Copy indices.
  IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);

  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
    if (CSRHasData(csr)) {
      CUDA_CALL(
          cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
    }
  }

  CUDA_KERNEL_CALL(
      _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, indices_data,
      rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
      ret_indices.Ptr<IdType>());
  // Copy data.
  IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
  CUDA_KERNEL_CALL(
      _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, data_data,
      rows.Ptr<IdType>(), nnz, len, ret_indptr.Ptr<IdType>(),
      ret_data.Ptr<IdType>());
  return CSRMatrix(
      len, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
}

template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, NDArray);

///////////////////////////// CSRGetDataAndIndices /////////////////////////////

/**
 * @brief Generate a 0-1 mask for each index that hits the provided (row, col)
 *        index.
 *
 * Examples:
 * Given a CSR matrix (with duplicate entries) as follows:
 * [[0, 1, 2, 0, 0],
 *  [1, 0, 0, 0, 0],
 *  [0, 0, 1, 1, 0],
 *  [0, 0, 0, 0, 0]]
 * Given rows: [0, 1], cols: [0, 2, 3]
 * The result mask is: [0, 1, 1, 1, 0, 0]
 */
template <typename IdType>
__global__ void _SegmentMaskKernel(
    const IdType* indptr, const IdType* indices, const IdType* row,
    const IdType* col, int64_t row_stride, int64_t col_stride, int64_t length,
    IdType* mask) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    int rpos = tx * row_stride, cpos = tx * col_stride;
    const IdType r = row[rpos], c = col[cpos];
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      if (indices[i] == c) {
        mask[i] = 1;
      }
    }
    tx += stride_x;
  }
}

/**
 * @brief Search for the insertion position of each needle in the hay.
 *
 * The hay is a sorted list of elements and the result is, for each needle,
 * the position at which it can be inserted so that the hay stays sorted.
 *
 * It essentially performs a binary search for the lower bound of each needle.
 * It requires that the largest element in the hay is larger than the given
 * needles. Commonly used for searching the row IDs of a given set of
 * coordinates.
 */
template <typename IdType>
__global__ void _SortedSearchKernel(
    const IdType* hay, int64_t hay_size, const IdType* needles,
    int64_t num_needles, IdType* pos) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_needles) {
    const IdType ele = needles[tx];
    // binary search
    IdType lo = 0, hi = hay_size - 1;
    while (lo < hi) {
      IdType mid = (lo + hi) >> 1;
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    pos[tx] = (hay[hi] == ele) ? hi : hi - 1;
    tx += stride_x;
  }
}

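/**
 * @brief Return the (row, col, data) arrays of all entries that match the
 * queried (row[i], col[i]) pairs.
 *
 * A 0-1 mask over the nnz positions is built with _SegmentMaskKernel, the hit
 * positions are extracted with NonZero, and their row IDs are recovered by a
 * sorted search over indptr.
 */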
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto len = std::max(rowlen, collen);
  if (len == 0) return {NullArray(), NullArray(), NullArray()};

  const auto& ctx = row->ctx;
  const auto nbits = row->dtype.bits;
  const int64_t nnz = csr.indices->shape[0];
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
  }

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, nnz, nbits, ctx);
  const int nt = dgl::cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentMaskKernel, nb, nt, 0, stream, indptr_data, indices_data,
      row.Ptr<IdType>(), col.Ptr<IdType>(), row_stride, col_stride, len,
      mask.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    // No data. Return three empty arrays.
    return {idx, idx, idx};

  // Search for row index
  IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
  const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
  const int nb2 = (idx->shape[0] + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SortedSearchKernel, nb2, nt2, 0, stream, indptr_data, csr.num_rows,
      idx.Ptr<IdType>(), idx->shape[0], ret_row.Ptr<IdType>());

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;
  return {ret_row, ret_col, ret_data};
}

template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols);

///////////////////////////// CSRSliceMatrix /////////////////////////////

/**
 * @brief Generate a 0-1 mask for each index whose column is in the provided
 * set. It also counts the number of masked values per row.
 */
template <typename IdType>
__global__ void _SegmentMaskColKernel(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int64_t num_nnz, const IdType* col, int64_t col_len, IdType* mask,
    IdType* count) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_nnz) {
    IdType rpos = dgl::cuda::_UpperBound(indptr, num_rows, tx) - 1;
    IdType cur_c = indices[tx];
    IdType i = dgl::cuda::_BinarySearch(col, col_len, cur_c);
    if (i < col_len) {
      mask[tx] = 1;
      cuda::AtomicAdd(count + rpos, IdType(1));
    }
    tx += stride_x;
  }
}

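/**
 * @brief Slice the submatrix defined by `rows` and `cols` and relabel the
 * result to the new row/column space.
 *
 * Rows are sliced first with CSRSliceRows; the remaining non-zeros are then
 * masked by whether their column appears in the (sorted) `cols` set, and the
 * new indptr is the cumulative sum of the per-row hit counts.
 */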
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = rows->ctx;
  const auto& dtype = rows->dtype;
  const auto nbits = dtype.bits;
  const int64_t new_nrows = rows->shape[0];
  const int64_t new_ncols = cols->shape[0];

  if (new_nrows == 0 || new_ncols == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // First slice rows
  csr = CSRSliceRows(csr, rows);

  if (csr.indices->shape[0] == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Generate a 0-1 mask for matched (row, col) positions.
  IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
  // A count for how many masked values per row.
  IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
  CUDA_CALL(
      cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows)));

  const int64_t nnz_csr = csr.indices->shape[0];
  const int nt = 256;

  // In general the ``cols`` array is sorted, but this is not guaranteed,
  // so sort a copy of it first. The sort is not in place.
  auto device = runtime::DeviceAPI::Get(ctx);
  auto cols_size = cols->shape[0];

  IdArray sorted_array = NewIdArray(cols->shape[0], ctx, cols->dtype.bits);
  auto ptr_sorted_cols = sorted_array.Ptr<IdType>();
  auto ptr_cols = cols.Ptr<IdType>();
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceRadixSort::SortKeys(
      nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0,
      sizeof(IdType) * 8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUDA_CALL(cub::DeviceRadixSort::SortKeys(
      workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0,
      sizeof(IdType) * 8, stream));
  device->FreeWorkspace(ctx, workspace);

  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
  }

  // Execute SegmentMaskColKernel
  int nb = (nnz_csr + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentMaskColKernel, nb, nt, 0, stream, indptr_data, indices_data,
      csr.num_rows, nnz_csr, ptr_sorted_cols, cols_size, mask.Ptr<IdType>(),
      count.Ptr<IdType>());

  IdArray idx = AsNumBits(NonZero(mask), nbits);
  if (idx->shape[0] == 0)
    return CSRMatrix(
        new_nrows, new_ncols, Full(0, new_nrows + 1, nbits, ctx),
        NullArray(dtype, ctx), NullArray(dtype, ctx));

  // Indptr needs to be adjusted according to the new nnz per row.
  IdArray ret_indptr = CumSum(count, true);

  // Column & data can be obtained by index select.
  IdArray ret_col = IndexSelect(csr.indices, idx);
  IdArray ret_data = CSRHasData(csr) ? IndexSelect(csr.data, idx) : idx;

  // Relabel column
  IdArray col_hash = NewIdArray(csr.num_cols, ctx, nbits);
  Scatter_(cols, Range(0, cols->shape[0], nbits, ctx), col_hash);
  ret_col = IndexSelect(col_hash, ret_col);

  return CSRMatrix(new_nrows, new_ncols, ret_indptr, ret_col, ret_data);
}

template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);

}  // namespace impl
}  // namespace aten
}  // namespace dgl