#include "hip/hip_runtime.h"
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_sort.cu
 * \brief Sort CSR indices
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Kernel to check whether the column indices of each row are sorted
 *        (non-decreasing).
 */
template <typename IdType>
__global__ void _SegmentIsSorted(
    const IdType* indptr, const IdType* indices,
    int64_t num_rows, int8_t* flags) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
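  // Grid-stride loop: each thread checks one or more rows; a row passes if its
  // column indices in [indptr[tx], indptr[tx + 1]) are non-decreasing.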
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] <= indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const auto& ctx = csr.indptr->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes (one flag per row). It wastes a
  // little bit of memory but should be fine.
  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentIsSorted,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, flags);
  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return ret;
}

template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);

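/*!
 * \brief Sort the column indices of every row of the CSR matrix in place.
 *
 * If the matrix carries no data array, an arange(0, nnz) array is attached
 * first so that the sort also records the permutation of the original entries.
 *
 * Illustrative example (hypothetical input, not from the original source):
 *   indptr  = [0, 3, 5]
 *   indices = [2, 0, 1, 4, 3]  ->  [0, 1, 2, 3, 4]
 *   data    = [0, 1, 2, 3, 4]  ->  [1, 2, 0, 4, 3]
 */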
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  LOG(FATAL) << "Unreachable code";
}

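// 32-bit indices: hipsparseXcsrsort sorts the column indices within each row
// and applies the same permutation to the 32-bit data array.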
template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
  hipStream_t stream = runtime::getCurrentCUDAStream();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray indptr = csr->indptr;
  NDArray indices = csr->indices;
  const auto& ctx = indptr->ctx;
  const int64_t nnz = indices->shape[0];
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);
  NDArray data = csr->data;

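  // Query the size of the temporary buffer required by csrsort, then allocate it.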
  size_t workspace_size = 0;
  CUSPARSE_CALL(hipsparseXcsrsort_bufferSizeExt(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(hipsparseXcsrsort(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      descr,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      data.Ptr<int32_t>(),
      workspace));

  csr->sorted = true;

  // free resources
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
  device->FreeWorkspace(ctx, workspace);
}

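// 64-bit indices: cusparse/hipsparse csrsort only handles 32-bit indices, so
// sort each row segment with hipcub::DeviceSegmentedRadixSort instead, using
// indptr as the segment offsets.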
template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);

  const auto& ctx = csr->indptr->ctx;
  const int64_t nnz = csr->indices->shape[0];
  const auto nbits = csr->indptr->dtype.bits;
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, nbits, ctx);

  IdArray new_indices = csr->indices.Clone();
  IdArray new_data = csr->data.Clone();

  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
  const int64_t* key_in = csr->indices.Ptr<int64_t>();
  int64_t* key_out = new_indices.Ptr<int64_t>();
  const int64_t* value_in = csr->data.Ptr<int64_t>();
  int64_t* value_out = new_data.Ptr<int64_t>();

  // Allocate workspace
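  // (passing a null temp-storage pointer makes SortPairs only report the
  //  required temporary storage size in workspace_size)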
  size_t workspace_size = 0;
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Sort the (index, data) pairs within each row segment
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));

  csr->sorted = true;
  csr->indices = new_indices;
  csr->data = new_data;

  // free resources
  device->FreeWorkspace(ctx, workspace);
}

template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl