coo2csr.cu 4.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/coo2csr.cu
 * \brief COO2CSR
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Generic fallback for COO to CSR conversion.
 *
 * This primary template only logs a fatal "Unreachable code." message; the
 * actual conversions are provided by the explicit specializations in this file.
 *
 * \param coo Input COO matrix (unused).
 * \return A default-constructed CSRMatrix, present only to satisfy the
 *         function's return type.
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  LOG(FATAL) << "Unreachable code.";
  return CSRMatrix();
}

/*!
 * \brief Convert a COO matrix with 32-bit indices to CSR using cuSPARSE.
 *
 * The COO is sorted by row first if it is not already, a data array is
 * materialized when the COO has none, and the row indices are then compressed
 * into an indptr array by cusparseXcoo2csr.
 *
 * \param coo Input COO matrix (taken by value; may be replaced by its sorted copy).
 * \return CSR matrix that reuses the (possibly sorted) column and data arrays.
 */
template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
  auto* cuda_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // Create the thread-local cuSPARSE handle on first use, then bind it to
  // the thread's current stream.
  if (!cuda_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(cuda_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(cuda_entry->cusparse_handle, cuda_entry->stream));

  bool rows_in_order = coo.row_sorted;
  bool cols_in_order = coo.col_sorted;
  if (!rows_in_order) {
    // The flag may simply be unset (it defaults to false), so scan the
    // arrays to find out whether they are actually sorted.
    std::tie(rows_in_order, cols_in_order) = COOIsSorted(coo);
    if (!rows_in_order) {
      // Really unsorted: sort and take the column-sorted flag from the result.
      coo = COOSort(coo);
      cols_in_order = coo.col_sorted;
    }
  }

  const int64_t num_nonzeros = coo.row->shape[0];
  const auto id_bits = coo.row->dtype.bits;
  const auto dev_ctx = coo.row->ctx;

  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo)) {
    coo.data = aten::Range(0, num_nonzeros, id_bits, dev_ctx);
  }

  // Compress the sorted row indices into num_rows + 1 row offsets.
  NDArray row_offsets = aten::NewIdArray(coo.num_rows + 1, dev_ctx, id_bits);
  int32_t* row_offsets_data = static_cast<int32_t*>(row_offsets->data);
  CUSPARSE_CALL(cusparseXcoo2csr(
        cuda_entry->cusparse_handle,
        coo.row.Ptr<int32_t>(),
        static_cast<int>(num_nonzeros),
        static_cast<int>(coo.num_rows),
        row_offsets_data,
        CUSPARSE_INDEX_BASE_ZERO));

  return CSRMatrix(coo.num_rows, coo.num_cols,
                   row_offsets, coo.col, coo.data, cols_in_order);
}

/*!
 * \brief Search for the insertion positions for needle in the hay.
 *
 * The hay is a list of sorted elements and the result is the insertion position
 * of each needle so that the insertion still gives sorted order.
 *
 * It essentially performs a binary search to find the upper bound for each
 * needle element, i.e. the number of hay elements that are <= the needle.
 *
 * For example:
 * hay = [0, 0, 1, 2, 2]
 * needle = [0, 1, 2, 3]
 * then,
 * out = [2, 3, 5, 5]
 *
 * \param hay Sorted array to search in.
 * \param hay_size Number of elements in hay.
 * \param needles Values to look up.
 * \param num_needles Number of elements in needles.
 * \param pos Output array with num_needles entries; pos[i] receives the
 *            upper-bound position of needles[i] in hay.
 *
 * \note Each thread starts at its global x index and advances by the total
 *       number of launched threads along x, so every needle is processed
 *       regardless of how many blocks are launched.
 */
template <typename IdType>
__global__ void _SortedSearchKernelUpperBound(
    const IdType* hay, int64_t hay_size,
    const IdType* needles, int64_t num_needles,
    IdType* pos) {
  // Index with 64-bit integers: num_needles is int64_t, and both the global
  // thread index and the stride can exceed the 32-bit range for large inputs.
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < num_needles;
       tx += stride_x) {
    const IdType ele = needles[tx];
    // binary search over the half-open range [lo, hi)
    IdType lo = 0, hi = hay_size;
    while (lo < hi) {
      // lo + (hi - lo) / 2 cannot overflow, unlike (lo + hi) / 2
      const IdType mid = lo + ((hi - lo) >> 1);
      if (hay[mid] <= ele) {
        lo = mid + 1;
      } else {
        hi = mid;
      }
    }
    pos[tx] = lo;
  }
}

/*!
 * \brief Convert a COO matrix with 64-bit indices to CSR.
 *
 * The COO is sorted by row first if it is not already, and a data array is
 * materialized when the COO has none. The indptr array is then built by
 * searching, for every row id, the upper-bound position of that id in the
 * sorted row array: indptr[0] stays 0 and indptr[r + 1] receives the number of
 * entries whose row id is <= r.
 *
 * \param coo Input COO matrix (taken by value; may be replaced by its sorted copy).
 * \return CSR matrix that reuses the (possibly sorted) column and data arrays.
 */
template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
  // Copy the context and bit width by value: coo may be reassigned by COOSort
  // below, and a reference into the old row array could then dangle.
  const auto ctx = coo.row->ctx;
  const auto nbits = coo.row->dtype.bits;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  bool row_sorted = coo.row_sorted;
  bool col_sorted = coo.col_sorted;
  if (!row_sorted) {
    // It is possible that the flag is simply not set (default value is false),
    // so we still perform a linear scan to check the flag.
    std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
  }
  if (!row_sorted) {
    coo = COOSort(coo);
    col_sorted = coo.col_sorted;
  }

  const int64_t nnz = coo.row->shape[0];
  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo))
    coo.data = aten::Range(0, nnz, nbits, ctx);

  // indptr[0] is left at zero; the kernel fills entries 1..num_rows.
  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
  // With zero rows there is nothing to search, and the block count computed
  // below would be zero, which is not a valid launch configuration.
  if (coo.num_rows > 0) {
    IdArray rowids = Range(0, coo.num_rows, nbits, ctx);
    const int nt = cuda::FindNumThreads(coo.num_rows);
    const int nb = (coo.num_rows + nt - 1) / nt;
    CUDA_KERNEL_CALL(_SortedSearchKernelUpperBound,
        nb, nt, 0, thr_entry->stream,
        coo.row.Ptr<int64_t>(), nnz,
        rowids.Ptr<int64_t>(), coo.num_rows,
        indptr.Ptr<int64_t>() + 1);
  }

  return CSRMatrix(coo.num_rows, coo.num_cols,
                   indptr, coo.col, coo.data, col_sorted);
}

// Explicit instantiations of the GPU conversion for 32-bit and 64-bit index types.
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);

}  // namespace impl
}  // namespace aten
}  // namespace dgl