/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr2coo.cc
 * @brief CSR2COO
 */
#include <dgl/array.h>

#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cub/cub.cuh>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

// Generic fallback: CSRToCOO has no implementation for this device/index
// type combination. The CUDA int32/int64 specializations follow below.
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
  LOG(FATAL) << "Unreachable codes";
  return {};
}

template <>
COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
  // Expand the int32 CSR row-pointer array into one row id per nonzero
  // using cuSPARSE's csr2coo conversion.
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Lazily create the thread-local cuSPARSE handle and bind it to the
  // current stream so the conversion runs in stream order.
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;

  // One output row id per nonzero, allocated on the same device and with
  // the same index width as indptr.
  NDArray row =
      aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits);

  CUSPARSE_CALL(cusparseXcsr2coo(
      thr_entry->cusparse_handle, static_cast<int32_t*>(indptr->data),
      indices->shape[0], csr.num_rows, static_cast<int32_t*>(row->data),
      CUSPARSE_INDEX_BASE_ZERO));

  // Row ids are emitted in nondecreasing order (row_sorted = true); column
  // sortedness is inherited from the input CSR.
  return COOMatrix(
      csr.num_rows, csr.num_cols, row, indices, data, true, csr.sorted);
}

// Functor mapping a row id to a constant iterator that repeats that id.
// Used as the batched-copy *source* so that row r is written once for each
// nonzero belonging to row r.
struct RepeatIndex {
  template <typename IdType>
  __host__ __device__ auto operator()(IdType row) {
    return thrust::make_constant_iterator(row);
  }
};

// Functor giving the destination pointer for row r: its output segment
// starts at buffer + indptr[r].
template <typename IdType>
struct OutputBufferIndexer {
  const IdType* indptr;
  IdType* buffer;
  __host__ __device__ auto operator()(IdType row) {
    return buffer + indptr[row];
  }
};

// Functor giving the length of row r's output segment, i.e. the number of
// nonzeros in that row.
template <typename IdType>
struct AdjacentDifference {
  const IdType* indptr;
  __host__ __device__ auto operator()(IdType row) {
    return indptr[row + 1] - indptr[row];
  }
};
template <>
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
  // Expand the int64 row pointer into per-nonzero row ids with a CUB
  // batched copy: for each row r, write r repeated
  // (indptr[r+1] - indptr[r]) times into the output at offset indptr[r].
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t nnz = csr.indices->shape[0];
  const auto nbits = csr.indptr->dtype.bits;
  IdArray ret_row = NewIdArray(nnz, ctx, nbits);

  runtime::CUDAWorkspaceAllocator allocator(csr.indptr->ctx);
  thrust::counting_iterator<int64_t> row_ids(0);

  // Per-row source stream (the row id, repeated), destination pointer and
  // segment length, all generated lazily from the row index.
  auto src = thrust::make_transform_iterator(row_ids, RepeatIndex{});
  auto dst = thrust::make_transform_iterator(
      row_ids, OutputBufferIndexer<int64_t>{
                   csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>()});
  auto sizes = thrust::make_transform_iterator(
      row_ids, AdjacentDifference<int64_t>{csr.indptr.Ptr<int64_t>()});

  // The batch count parameter is 32-bit, so process the rows in chunks of
  // at most INT32_MAX.
  constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();
  for (int64_t start = 0; start < csr.num_rows; start += max_copy_at_once) {
    const int64_t batch = std::min(csr.num_rows - start, max_copy_at_once);

    // First call only sizes the temporary storage; second does the copy.
    std::size_t temp_storage_bytes = 0;
    CUDA_CALL(cub::DeviceCopy::Batched(
        nullptr, temp_storage_bytes, src + start, dst + start, sizes + start,
        batch, stream));

    auto temp = allocator.alloc_unique<char>(temp_storage_bytes);

    CUDA_CALL(cub::DeviceCopy::Batched(
        temp.get(), temp_storage_bytes, src + start, dst + start,
        sizes + start, batch, stream));
  }

  // Row ids are emitted in row order (row_sorted = true); column
  // sortedness is inherited from the input CSR.
  return COOMatrix(
      csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
      csr.sorted);
}

template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);

// Generic fallback: CSRToCOODataAsOrder has no implementation for this
// device/index type combination. The CUDA specializations follow below.
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
  LOG(FATAL) << "Unreachable codes";
  return {};
}

template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
  // Convert to COO, then reorder the (row, col) pairs in place so they
  // follow the order given by the data array. cuSPARSE's COO sort is used
  // with data as the primary sort key.
  COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
  // Null data means the entries are already in natural order.
  if (aten::IsNullArray(coo.data)) return coo;

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo.row->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Lazily create the thread-local cuSPARSE handle and bind the stream.
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray row = coo.row, col = coo.col, data = coo.data;
  int32_t* row_data = static_cast<int32_t*>(row->data);
  int32_t* col_data = static_cast<int32_t*>(col->data);
  int32_t* key_data = static_cast<int32_t*>(data->data);

  // Query the scratch size, run the in-place sort, release the scratch.
  // Note: data is passed in the "rows" slot so it acts as the sort key.
  size_t buf_bytes = 0;
  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
      thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
      key_data, row_data, &buf_bytes));
  void* buf = device->AllocWorkspace(row->ctx, buf_bytes);
  CUSPARSE_CALL(cusparseXcoosortByRow(
      thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
      key_data, row_data, col_data, buf));
  device->FreeWorkspace(row->ctx, buf);

  // The row and column field have already been reordered according
  // to data, thus the data field will be deprecated.
  coo.data = aten::NullArray();
  coo.row_sorted = false;
  coo.col_sorted = false;
  return coo;
}

template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
  // Convert to COO, then permute (row, col) into the order given by the
  // data array: sort data and gather row/col through the permutation.
  COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
  // Null data means the entries are already in natural order.
  if (aten::IsNullArray(coo.data)) return coo;

  // sorted.second holds the index permutation produced by sorting coo.data.
  const auto& sorted = Sort(coo.data);
  coo.row = IndexSelect(coo.row, sorted.second);
  coo.col = IndexSelect(coo.col, sorted.second);

  // The row and column field have already been reordered according
  // to data, thus the data field will be deprecated.
  coo.data = aten::NullArray();
  coo.row_sorted = false;
  coo.col_sorted = false;
  return coo;
}

template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl