/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_sort.cu
 * \brief Sort CSR index
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Check whether the column indices of each row are sorted in
 *        ascending order, writing one flag per row.
 */
template <typename IdType>
__global__ void _SegmentIsSorted(
    const IdType* indptr, const IdType* indices,
    int64_t num_rows, int8_t* flags) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
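  // Grid-stride loop: each thread checks one or more rows.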
  while (tx < num_rows) {
    bool f = true;
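    // Stop scanning a row as soon as an out-of-order pair is found.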
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] <= indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const auto& ctx = csr.indptr->ctx;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(ctx);
  // Allocate one flag byte per row. This wastes a little memory compared to a
  // bit vector, but should be fine.
  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentIsSorted,
      nb, nt, 0, thr_entry->stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, flags);
  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return ret;
}

template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
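
/*!
 * \brief Example usage (a minimal sketch; `csr` is assumed to be a
 *        device-resident CSRMatrix with 32-bit indices): sort only when the
 *        check reports unsorted rows.
 * \code
 *   if (!CSRIsSorted<kDLGPU, int32_t>(csr))
 *     CSRSort_<kDLGPU, int32_t>(&csr);
 * \endcode
 */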

template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  LOG(FATAL) << "Unreachable code.";
}

template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));

  NDArray indptr = csr->indptr;
  NDArray indices = csr->indices;
  const auto& ctx = indptr->ctx;
  const int64_t nnz = indices->shape[0];
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);
  NDArray data = csr->data;

  size_t workspace_size = 0;
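  // Phase 1 of cuSPARSE's two-phase workspace pattern: query the buffer size.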
  CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  cusparseMatDescr_t descr;
  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
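  // Phase 2: the actual sort. cusparseXcsrsort reorders `indices` within each
  // row in place and applies the same permutation to `data`.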
  CUSPARSE_CALL(cusparseXcsrsort(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      descr,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      data.Ptr<int32_t>(),
      workspace));

  csr->sorted = true;

  // free resources
  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
  device->FreeWorkspace(ctx, workspace);
}
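
// cusparseXcsrsort only accepts 32-bit index arrays, so the 64-bit
// specialization below falls back to CUB's segmented radix sort instead.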

template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);

  const auto& ctx = csr->indptr->ctx;
  const int64_t nnz = csr->indices->shape[0];
  const auto nbits = csr->indptr->dtype.bits;
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, nbits, ctx);

  IdArray new_indices = csr->indices.Clone();
  IdArray new_data = csr->data.Clone();

  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
  const int64_t* key_in = csr->indices.Ptr<int64_t>();
  int64_t* key_out = new_indices.Ptr<int64_t>();
  const int64_t* value_in = csr->data.Ptr<int64_t>();
  int64_t* value_out = new_data.Ptr<int64_t>();

  // Allocate workspace: a first SortPairs call with a null workspace pointer
  // only computes the required temporary-storage size.
  size_t workspace_size = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1);
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Sort the (column index, data) pairs within each row segment.
  cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1);

  csr->sorted = true;
  csr->indices = new_indices;
  csr->data = new_data;

  // free resources
  device->FreeWorkspace(ctx, workspace);
}
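
/*!
 * \brief A minimal standalone sketch of the CUB query-then-run idiom used
 *        above (names like `keys_in` and `offsets` are illustrative, not DGL
 *        APIs): the first SortPairs call with a null workspace pointer only
 *        computes the temporary-storage size.
 * \code
 *   size_t tmp_bytes = 0;
 *   cub::DeviceSegmentedRadixSort::SortPairs(
 *       nullptr, tmp_bytes, keys_in, keys_out, vals_in, vals_out,
 *       num_items, num_segments, offsets, offsets + 1);  // size query only
 *   void* tmp = device->AllocWorkspace(ctx, tmp_bytes);
 *   cub::DeviceSegmentedRadixSort::SortPairs(
 *       tmp, tmp_bytes, keys_in, keys_out, vals_in, vals_out,
 *       num_items, num_segments, offsets, offsets + 1);  // actual sort
 *   device->FreeWorkspace(ctx, tmp);
 * \endcode
 */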

template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl