/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_sort.cu
 * \brief Sort CSR index
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Check whether each row is sorted.
 */
template <typename IdType>
__global__ void _SegmentIsSorted(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int8_t* flags) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] <= indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}
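
// Illustrative example (values assumed for exposition, not taken from the
// code above): with
//   indptr  = {0, 3, 5}
//   indices = {1, 3, 3, 7, 2}
// row 0 scans {1, 3, 3} (non-decreasing, so flags[0] = 1) and row 1 scans
// {7, 2} (decreasing, so flags[1] = 0); the matrix is then reported unsorted.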

template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const auto& ctx = csr.indptr->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes (one flag per row). It wastes a
  // bit of memory but should be fine.
  int8_t* flags =
      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentIsSorted, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), csr.num_rows, flags);
  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return ret;
}

template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);
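
// A hypothetical caller-side sketch (for illustration only; the public DGL
// entry points presumably dispatch on device and id dtype through the aten
// API). Only the templates declared in this file are assumed here:
//
//   CSRMatrix csr = ...;  // CSR matrix living on a CUDA device, int32 ids
//   if (!impl::CSRIsSorted<kDGLCUDA, int32_t>(csr))
//     impl::CSRSort_<kDGLCUDA, int32_t>(&csr);  // sort columns of each row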

// Catch-all primary template; only the CUDA specializations below (int32 via
// cuSPARSE, int64 via CUB) are expected to be instantiated.
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  LOG(FATAL) << "Unreachable code.";
}

template <>
void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray indptr = csr->indptr;
  NDArray indices = csr->indices;
  const auto& ctx = indptr->ctx;
  const int64_t nnz = indices->shape[0];
  // If the matrix carries no data array, attach an identity mapping so the
  // sort records the permutation applied to the original entries.
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);
  NDArray data = csr->data;

  // Query the temporary buffer size required by cusparseXcsrsort, then
  // allocate it before the actual sort.
  size_t workspace_size = 0;
  CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  cusparseMatDescr_t descr;
  CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
  CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(cusparseXcsrsort(
      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz, descr,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), data.Ptr<int32_t>(),
      workspace));

  csr->sorted = true;

  // free resources
  CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
  device->FreeWorkspace(ctx, workspace);
}
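
// Note on the cuSPARSE path above: cusparseXcsrsort rewrites csr->data in
// place as a permutation map. Small assumed example: a single row with
//   indices = {4, 1, 2}, data = {0, 1, 2} (identity)
// comes back as indices = {1, 2, 4}, data = {1, 2, 0}, i.e. data[k] holds the
// original position of the k-th entry after sorting.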

template <>
void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);

  const auto& ctx = csr->indptr->ctx;
  const int64_t nnz = csr->indices->shape[0];
  const auto nbits = csr->indptr->dtype.bits;
  if (!aten::CSRHasData(*csr)) csr->data = aten::Range(0, nnz, nbits, ctx);

  IdArray new_indices = csr->indices.Clone();
  IdArray new_data = csr->data.Clone();

  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
  const int64_t* key_in = csr->indices.Ptr<int64_t>();
  int64_t* key_out = new_indices.Ptr<int64_t>();
  const int64_t* value_in = csr->data.Ptr<int64_t>();
  int64_t* value_out = new_data.Ptr<int64_t>();

  // Allocate workspace: the first CUB call below runs with a null workspace
  // pointer and only computes the required temporary storage size.
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, workspace_size, key_in, key_out, value_in, value_out, nnz,
      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Compute: sort the column indices of each row segment, carrying csr->data
  // along as values.
  CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
      workspace, workspace_size, key_in, key_out, value_in, value_out, nnz,
      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));

  csr->sorted = true;
  csr->indices = new_indices;
  csr->data = new_data;

  // free resources
  device->FreeWorkspace(ctx, workspace);
}
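
// Sketch of what the segmented radix sort above computes (values assumed for
// illustration): with offsets = {0, 2, 5},
//   key_in   = {9, 3 | 8, 1, 5}   value_in  = {0, 1 | 2, 3, 4}
// the output is
//   key_out  = {3, 9 | 1, 5, 8}   value_out = {1, 0 | 3, 4, 2}
// Each segment (row) is sorted independently and values travel with their
// keys, which is how csr->data keeps tracking the original entry positions.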

template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl