/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/coo2csr.cc
 * \brief COO2CSR
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Convert a COO matrix to CSR format using cuSPARSE.
 *
 * The COO is sorted with COOSort when neither its row_sorted flag nor a scan
 * via COOIsSorted reports it as row-sorted. The row index array is then
 * compressed into a row-pointer array with cusparseXcoo2csr, and the column
 * and data arrays of the (possibly sorted) COO are handed to the CSRMatrix
 * constructor together with the new row-pointer array.
 *
 * \tparam XPU Device type this kernel is instantiated for.
 * \tparam IdType Index type; only 4-byte ids are accepted (checked at runtime).
 * \param coo Input COO matrix. Taken by value because it may be replaced
 *        locally by a sorted copy and may be given a data array.
 * \return CSR matrix with the same num_rows and num_cols as the input.
 */
template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  CHECK(sizeof(IdType) == 4) << "CUDA COOToCSR does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));

  bool row_sorted = coo.row_sorted;
  bool col_sorted = coo.col_sorted;
  if (!row_sorted) {
    // It is possible that the flag is simply not set (default value is false),
    // so we still perform a linear scan to check the flag.
    std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
  }
  if (!row_sorted) {
    coo = COOSort(coo);
  }

  const int64_t nnz = coo.row->shape[0];
  // cusparseXcoo2csr takes the nnz and row counts as `int`. Narrow them
  // explicitly and verify the round trip so that oversized inputs fail loudly
  // instead of being truncated silently by an implicit conversion.
  const int nnz_int = static_cast<int>(nnz);
  const int num_rows_int = static_cast<int>(coo.num_rows);
  CHECK(nnz_int == nnz && num_rows_int == coo.num_rows)
      << "CUDA COOToCSR: nnz (" << nnz << ") or num_rows (" << coo.num_rows
      << ") does not fit in int.";

  // TODO(minjie): Many of our current implementation assumes that CSR must have
  //   a data array. This is a temporary workaround. Remove this after:
  //   - The old immutable graph implementation is deprecated.
  //   - The old binary reduce kernel is deprecated.
  if (!COOHasData(coo))
    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);

  // Row-pointer array has one entry per row plus a trailing end offset.
  NDArray indptr = aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  CUSPARSE_CALL(cusparseXcoo2csr(
        thr_entry->cusparse_handle,
        coo.row.Ptr<int32_t>(),
        nnz_int,
        num_rows_int,
        indptr_ptr,
        CUSPARSE_INDEX_BASE_ZERO));

  return CSRMatrix(coo.num_rows, coo.num_cols,
                   indptr, coo.col, coo.data, col_sorted);
}

// Explicit instantiations of COOToCSR for kDLGPU.
// NOTE: the int64_t instantiation is emitted, but calling it trips the
// `sizeof(IdType) == 4` CHECK at the top of COOToCSR.
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);


}  // namespace impl
}  // namespace aten
}  // namespace dgl