/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr2coo.cc
 * \brief CSR2COO
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Generic fallback for CSR-to-COO conversion.
 *
 * Only the explicit CUDA specializations below are implemented; reaching
 * this body indicates a dispatch error, so it aborts unconditionally.
 */
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
  LOG(FATAL) << "Unreachable codes";
  return COOMatrix();
}

template <>
COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
  // Expand the compressed row pointer into one explicit row id per nonzero
  // via cuSPARSE; column indices and data are shared with the input CSR.
  auto* entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Lazily create the thread-local cuSPARSE handle, then bind it to the
  // current stream so the conversion is ordered with other work.
  if (!entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(entry->cusparse_handle, stream));

  NDArray indptr = csr.indptr;
  NDArray indices = csr.indices;
  NDArray data = csr.data;
  const int64_t nnz = indices->shape[0];
  // Output buffer: one row index per nonzero, same ctx/bit-width as indptr.
  NDArray row = aten::NewIdArray(nnz, indptr->ctx, indptr->dtype.bits);

  CUSPARSE_CALL(cusparseXcsr2coo(
      entry->cusparse_handle,
      static_cast<const int32_t*>(indptr->data),
      nnz,
      csr.num_rows,
      static_cast<int32_t*>(row->data),
      CUSPARSE_INDEX_BASE_ZERO));

  // Row-major order of the CSR is preserved, hence row_sorted = true;
  // column sortedness carries over from the input.
  return COOMatrix(csr.num_rows, csr.num_cols,
                   row, indices, data,
                   true, csr.sorted);
}

/*!
 * \brief Expand values into an output buffer according to segment offsets.
 * \param val Values to repeat (one per segment).
 * \param pos Monotone start offsets of each segment in the output.
 * \param out Output buffer.
 * \param n_row Number of segments (length of \a pos).
 * \param length Total number of output slots.
 *
 * For example:
 * val = [3, 0, 1]
 * pos = [0, 1, 1]
 * length = 3
 * then,
 * out = [3, 1, 1]   # slot 0 belongs to segment 0, slots 1-2 to segment 2
 */
template <typename DType, typename IdType>
__global__ void _RepeatKernel(
    const DType* val, const IdType* pos,
    DType* out, int64_t n_row, int64_t length) {
  // Grid-stride loop: correct for any launch configuration.
  const int stride = gridDim.x * blockDim.x;
  for (IdType idx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < length; idx += stride) {
    // The segment owning output slot idx is the last one whose start
    // offset is <= idx, i.e. upper_bound(pos, idx) - 1.
    const IdType seg = dgl::cuda::_UpperBound(pos, n_row, idx) - 1;
    out[idx] = val[seg];
  }
}

template <>
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
  // cuSPARSE's csr2coo only supports 32-bit indices, so the 64-bit path
  // expands indptr into per-nonzero row ids with a custom kernel.
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t nnz = csr.indices->shape[0];
  const auto nbits = csr.indptr->dtype.bits;
  // rowids = [0, 1, ..., num_rows-1]; _RepeatKernel repeats entry r once per
  // nonzero of row r, using indptr as the segment offsets.
  IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
  IdArray ret_row = NewIdArray(nnz, ctx, nbits);

  // Skip the launch entirely when the matrix has no nonzeros: a grid of
  // zero blocks is an invalid launch configuration.
  if (nnz > 0) {
    const int nt = 256;
    const int64_t full = (nnz + nt - 1) / nt;
    // Clamp to the maximum x-dimension grid size instead of silently
    // narrowing int64 -> int; the kernel's grid-stride loop covers the rest.
    const int nb = full > 0x7FFFFFFFLL ? 0x7FFFFFFF : static_cast<int>(full);
    CUDA_KERNEL_CALL(_RepeatKernel,
        nb, nt, 0, stream,
        rowids.Ptr<int64_t>(),
        csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
        csr.num_rows, nnz);
  }

  // Row-major order is preserved; column sortedness carries over.
  return COOMatrix(csr.num_rows, csr.num_cols,
                   ret_row, csr.indices, csr.data,
                   true, csr.sorted);
}

// Explicit instantiations of the CUDA CSRToCOO specializations defined above.
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);
/*!
 * \brief Generic fallback for data-ordered CSR-to-COO conversion.
 *
 * Only the explicit CUDA specializations below are implemented; reaching
 * this body indicates a dispatch error, so it aborts unconditionally.
 */
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  LOG(FATAL) << "Unreachable codes";
  return COOMatrix();
}

template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
  // Convert to COO first, then reorder (row, col) so that entry i of the
  // result corresponds to data value i (data becomes the implicit identity).
  COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);

  // A null data field already denotes identity order - nothing to do.
  if (aten::IsNullArray(coo.data))
    return coo;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo.row->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray row = coo.row, col = coo.col, data = coo.data;
  int32_t* row_ptr = static_cast<int32_t*>(row->data);
  int32_t* col_ptr = static_cast<int32_t*>(col->data);
  int32_t* data_ptr = static_cast<int32_t*>(data->data);

  // cusparseXcoosortByRow is repurposed to sort by `data`: data is passed in
  // the "row" slot (sort key), row rides along in the "column" slot, and col
  // is passed as the permutation array P.
  // NOTE(review): this relies on coosort gathering P through the sort
  // permutation so that col also ends up reordered by data - confirm against
  // the cuSPARSE documentation for the deployed toolkit version.
  size_t workspace_size = 0;
  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      data_ptr,
      row_ptr,
      &workspace_size));
  void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
  CUSPARSE_CALL(cusparseXcoosortByRow(
      thr_entry->cusparse_handle,
      coo.num_rows, coo.num_cols,
      row->shape[0],
      data_ptr,
      row_ptr,
      col_ptr,
      workspace));
  device->FreeWorkspace(row->ctx, workspace);

  // The row and column field have already been reordered according
  // to data, thus the data field will be deprecated.
  coo.data = aten::NullArray();
  // Sorting by data destroys any row/col ordering guarantee from CSRToCOO.
  coo.row_sorted = false;
  coo.col_sorted = false;
  return coo;
}

template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
  // Convert to COO first, then reorder (row, col) so that entry i of the
  // result corresponds to data value i (data becomes the implicit identity).
  COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);

  // A null data field already denotes identity order - nothing to do.
  if (aten::IsNullArray(coo.data)) {
    return coo;
  }

  // Sort the data array; the second element of the result is the
  // permutation that puts data into ascending order.
  const auto& sort_result = Sort(coo.data);
  const auto& perm = sort_result.second;
  coo.row = IndexSelect(coo.row, perm);
  coo.col = IndexSelect(coo.col, perm);

  // The row and column field have already been reordered according
  // to data, thus the data field will be deprecated.
  coo.data = aten::NullArray();
  // Sorting by data destroys any row/col ordering guarantee from CSRToCOO.
  coo.row_sorted = false;
  coo.col_sorted = false;
  return coo;
}

// Explicit instantiations of the CUDA CSRToCOODataAsOrder specializations.
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl