coo2csr.cu 4.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include "hip/hip_runtime.h"
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/coo2csr.cu
 * \brief COO2CSR
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Generic COO-to-CSR entry point.
 *
 * This primary template only logs a fatal error; the usable implementations
 * are the explicit specializations for concrete device/index-type pairs.
 *
 * \param coo The input COO matrix (unused here).
 * \return An empty CSRMatrix, present only to satisfy the return type.
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  // Reaching the primary template means no specialization matched.
  LOG(FATAL) << "Unreachable code.";
  return {};
}

/*!
 * \brief Convert a COO matrix with 32-bit indices to CSR.
 *
 * The COO is row-sorted first when it is not already flagged as row-sorted,
 * a data array is created when COOHasData reports none, and the row-index
 * array is then compressed into an indptr array by hipsparseXcoo2csr.
 *
 * \param coo The input COO matrix; taken by value and possibly replaced by
 *            its row-sorted form.
 * \return A CSRMatrix built from the new indptr and the COO's col/data arrays.
 */
template <>
CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo) {
  auto* tls = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentCUDAStream();
  // Lazily create the per-thread sparse handle, then bind it to the stream.
  if (!tls->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(tls->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(tls->cusparse_handle, stream));

  // Only the rows have to be sorted for the conversion to work.
  if (!coo.row_sorted) {
    coo = COOSort(coo, false);
  }
  // Read after the optional sort so the flag describes the arrays we return.
  const bool col_sorted = coo.col_sorted;

  const int64_t nnz = coo.row->shape[0];
  const auto nbits = coo.row->dtype.bits;
  const auto ctx = coo.row->ctx;

  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo)) {
    coo.data = aten::Range(0, nnz, nbits, ctx);
  }

  // indptr has one entry per row plus a trailing end offset.
  NDArray indptr = aten::NewIdArray(coo.num_rows + 1, ctx, nbits);
  CUSPARSE_CALL(hipsparseXcoo2csr(
        tls->cusparse_handle,
        coo.row.Ptr<int32_t>(),
        nnz,
        coo.num_rows,
        indptr.Ptr<int32_t>(),
        HIPSPARSE_INDEX_BASE_ZERO));

  return CSRMatrix(coo.num_rows, coo.num_cols,
                   indptr, coo.col, coo.data, col_sorted);
}

/*!
 * \brief Search for the insertion positions for needle in the hay.
 *
 * The hay is a list of sorted elements and the result is the insertion position
 * of each needle so that the insertion still gives sorted order.
 *
 * It essentially performs a binary search to find the upper bound for each
 * needle element.
 *
 * For example:
 * hay = [0, 0, 1, 2, 2]
 * needle = [0, 1, 2, 3]
 * then,
 * out = [2, 3, 5, 5]
 *
 * Each thread starts at its global index and advances by the total number of
 * launched threads, so any 1-D launch configuration covers all needles.
 *
 * \param hay Sorted array to search in.
 * \param hay_size Number of elements in hay.
 * \param needles Values to look up.
 * \param num_needles Number of elements in needles.
 * \param pos Output array; pos[i] receives the upper-bound position of
 *            needles[i] in hay.
 */
template <typename IdType>
__global__ void _SortedSearchKernelUpperBound(
    const IdType* hay, int64_t hay_size,
    const IdType* needles, int64_t num_needles,
    IdType* pos) {
  // Use 64-bit indices: num_needles is int64_t, and a 32-bit product of
  // blockIdx.x * blockDim.x (or gridDim.x * blockDim.x) can overflow.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < num_needles) {
    const IdType ele = needles[tx];
    // binary search over the half-open range [lo, hi)
    IdType lo = 0, hi = hay_size;
    while (lo < hi) {
      // Overflow-safe midpoint; equals (lo + hi) >> 1 for 0 <= lo < hi.
      const IdType mid = lo + ((hi - lo) >> 1);
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    pos[tx] = lo;
    tx += stride_x;
  }
}

/*!
 * \brief Convert a COO matrix with 64-bit indices to CSR.
 *
 * The COO is row-sorted first when it is not already flagged as row-sorted
 * and a data array is created when COOHasData reports none. The indptr array
 * is then filled by launching _SortedSearchKernelUpperBound with the sorted
 * row array as the hay and the row ids [0, num_rows) as the needles, writing
 * to indptr[1..num_rows]; indptr[0] keeps the 0 it was initialized with.
 *
 * \param coo The input COO matrix; taken by value and possibly replaced by
 *            its row-sorted form.
 * \return A CSRMatrix built from the new indptr and the COO's col/data arrays.
 */
template <>
CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  bool row_sorted = coo.row_sorted;
  bool col_sorted = coo.col_sorted;
  if (!row_sorted) {
    coo = COOSort(coo, false);
    col_sorted = coo.col_sorted;
  }

  // Take the context and bit width by value, and only after the optional
  // sort: `coo` may have been reassigned above, so a reference obtained from
  // the previous row array could be left dangling.
  const auto ctx = coo.row->ctx;
  const auto nbits = coo.row->dtype.bits;

  const int64_t nnz = coo.row->shape[0];
  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo))
    coo.data = aten::Range(0, nnz, nbits, ctx);

  // One needle per row id.
  IdArray rowids = Range(0, coo.num_rows, nbits, ctx);
  const int nt = cuda::FindNumThreads(coo.num_rows);
  const int nb = (coo.num_rows + nt - 1) / nt;
  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
  // Results go to indptr + 1, leaving the leading entry at 0.
  CUDA_KERNEL_CALL(_SortedSearchKernelUpperBound,
      nb, nt, 0, stream,
      coo.row.Ptr<int64_t>(), nnz,
      rowids.Ptr<int64_t>(), coo.num_rows,
      indptr.Ptr<int64_t>() + 1);

  return CSRMatrix(coo.num_rows, coo.num_cols,
                   indptr, coo.col, coo.data, col_sorted);
}

lisj's avatar
lisj committed
137
138
// Explicit instantiations of COOToCSR for kDLROCM with 32-bit and 64-bit
// index types.
template CSRMatrix COOToCSR<kDLROCM, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLROCM, int64_t>(COOMatrix coo);
139
140
141
142

}  // namespace impl
}  // namespace aten
}  // namespace dgl