coo2csr.cu 4.2 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/coo2csr.cu
 * \brief COO2CSR
 */
#include <dgl/array.h>
7

8
9
10
11
12
13
14
15
16
17
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

18
/*!
 * \brief Primary template of the COO-to-CSR conversion.
 *
 * This generic version performs no conversion: it logs a fatal
 * "Unreachable code." message. The explicit specializations are the ones that
 * do the actual work.
 *
 * \tparam XPU The device type.
 * \tparam IdType The index type.
 * \param coo The input COO matrix (not used here).
 * \return A default-constructed CSRMatrix, present only to satisfy the
 *         function's return type.
 */
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  LOG(FATAL) << "Unreachable code.";
  return CSRMatrix();
}

/*!
 * \brief COO-to-CSR conversion for 32-bit indices on CUDA, using cuSPARSE.
 *
 * The input is row-sorted first when it is not already flagged as such, and a
 * data array is created when the input has none. The row-pointer array is then
 * filled by cusparseXcoo2csr, issued on the current CUDA stream.
 *
 * \param coo The input COO matrix.
 * \return A CSR matrix built from the new row-pointer array together with the
 *         column and data arrays of the (possibly sorted) COO matrix.
 */
template <>
CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
  auto* tls = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t cur_stream = runtime::getCurrentCUDAStream();

  // Create the thread-local cuSPARSE handle on first use, then bind it to the
  // current stream.
  auto& handle = tls->cusparse_handle;
  if (!handle) {
    CUSPARSE_CALL(cusparseCreate(&handle));
  }
  CUSPARSE_CALL(cusparseSetStream(handle, cur_stream));

  // Only the row order matters for the conversion, so sort by row alone.
  if (!coo.row_sorted) {
    coo = COOSort(coo, false);
  }

  const int64_t num_entries = coo.row->shape[0];
  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo)) {
    coo.data = aten::Range(0, num_entries, coo.row->dtype.bits, coo.row->ctx);
  }

  // One row-pointer slot per row plus the trailing end offset.
  NDArray row_offsets =
      aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
  CUSPARSE_CALL(cusparseXcoo2csr(
      handle, coo.row.Ptr<int32_t>(), num_entries, coo.num_rows,
      static_cast<int32_t*>(row_offsets->data), CUSPARSE_INDEX_BASE_ZERO));

  // coo.col_sorted reflects the matrix actually converted (sorted or not).
  return CSRMatrix(
      coo.num_rows, coo.num_cols, row_offsets, coo.col, coo.data,
      coo.col_sorted);
}

/*!
 * \brief Search for the insertion positions for needles in the hay.
 *
 * The hay is a list of sorted elements and the result is the insertion position
 * of each needle so that the insertion still gives sorted order.
 *
 * It essentially performs a binary search to find the upper bound for each
 * needle element, i.e. the number of hay elements that are <= the needle.
 *
 * For example:
 * hay = [0, 0, 1, 2, 2]
 * needle = [0, 1, 2, 3]
 * then,
 * out = [2, 3, 5, 5]
 *
 * Needles are processed in a grid-stride loop over a 1D launch, so any grid
 * size covers all of them.
 *
 * \param hay Sorted array to search in.
 * \param hay_size Number of elements in hay.
 * \param needles Values to look up.
 * \param num_needles Number of elements in needles.
 * \param pos Output array of num_needles elements; pos[i] receives the
 *        upper-bound position of needles[i] in hay.
 */
template <typename IdType>
__global__ void _SortedSearchKernelUpperBound(
    const IdType* hay, int64_t hay_size, const IdType* needles,
    int64_t num_needles, IdType* pos) {
  // Index with 64-bit integers: both sizes are int64_t, so 32-bit thread
  // indices or search bounds could overflow on large inputs.
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  while (tx < num_needles) {
    const IdType ele = needles[tx];
    // binary search over the half-open range [lo, hi)
    int64_t lo = 0, hi = hay_size;
    while (lo < hi) {
      // overflow-safe midpoint
      const int64_t mid = lo + ((hi - lo) >> 1);
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    pos[tx] = static_cast<IdType>(lo);
    tx += stride_x;
  }
}

/*!
 * \brief COO-to-CSR conversion for 64-bit indices on CUDA.
 *
 * The input is row-sorted first when it is not already flagged as such, and a
 * data array is created when the input has none. The row-pointer array starts
 * as all zeros; a kernel then writes, for each row id i, the upper-bound
 * position of i in the sorted row array into indptr[i + 1], leaving
 * indptr[0] == 0.
 *
 * \param coo The input COO matrix.
 * \return A CSR matrix built from the new row-pointer array together with the
 *         column and data arrays of the (possibly sorted) COO matrix.
 */
template <>
CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
  // Copy the context and bit width by value. `coo` may be reassigned by
  // COOSort below, so a reference into the original row array could be left
  // dangling once that array is released.
  const auto ctx = coo.row->ctx;
  const auto nbits = coo.row->dtype.bits;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  bool col_sorted = coo.col_sorted;
  if (!coo.row_sorted) {
    // we only need to sort the rows to perform conversion
    coo = COOSort(coo, false);
    col_sorted = coo.col_sorted;
  }

  const int64_t nnz = coo.row->shape[0];
  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo))
    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);

  // The needles are the row ids 0 .. num_rows - 1.
  IdArray rowids = Range(0, coo.num_rows, nbits, ctx);
  const int nt = cuda::FindNumThreads(coo.num_rows);
  const int nb = (coo.num_rows + nt - 1) / nt;
  // Zero-filled so that indptr[0] stays 0; the kernel writes from indptr[1].
  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
  CUDA_KERNEL_CALL(
      _SortedSearchKernelUpperBound, nb, nt, 0, stream, coo.row.Ptr<int64_t>(),
      nnz, rowids.Ptr<int64_t>(), coo.num_rows, indptr.Ptr<int64_t>() + 1);

  return CSRMatrix(
      coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);
}

131
132
// Explicit instantiations of COOToCSR on CUDA for the 32-bit and 64-bit index
// types, matching the two specializations defined in this file.
template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
133
134
135
136

}  // namespace impl
}  // namespace aten
}  // namespace dgl