#include "hip/hip_runtime.h"
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_sort.cu
 * \brief Sort CSR index
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/*!
 * \brief Check whether the column indices within each row are sorted.
 */
template <typename IdType>
__global__ void _SegmentIsSorted(
    const IdType* indptr, const IdType* indices,
    int64_t num_rows, int8_t* flags) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
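  // Grid-stride loop: each thread checks one or more rows and stops scanning
  // a row at the first out-of-order pair.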
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] <= indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const auto& ctx = csr.indptr->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(ctx);
  // Allocate one flag byte per row. This wastes a little memory but should
  // be fine.
  int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
  const int nt = cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(_SegmentIsSorted,
      nb, nt, 0, stream,
      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
      csr.num_rows, flags);
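  // Reduce the per-row flags on the device: the matrix is sorted iff every row is.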
  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return ret;
}

template bool CSRIsSorted<kDLROCM, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLROCM, int64_t>(CSRMatrix csr);
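
// Usage sketch: assuming a CSRMatrix `csr` whose arrays already live on a
// ROCm device, callers dispatch on the index dtype, e.g.
//   bool sorted = CSRIsSorted<kDLROCM, int64_t>(csr);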

template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  // Only the explicit specializations below are ever used.
  LOG(FATAL) << "Unreachable code";
}

template <>
void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
  hipStream_t stream = runtime::getCurrentCUDAStream();
  // Allocate a hipsparse handle for this thread if one is not already cached.
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray indptr = csr->indptr;
  NDArray indices = csr->indices;
  const auto& ctx = indptr->ctx;
  const int64_t nnz = indices->shape[0];
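  // If the matrix carries no edge data, attach an identity permutation so the
  // sort records where each entry moved.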
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);
  NDArray data = csr->data;

  size_t workspace_size = 0;
  CUSPARSE_CALL(hipsparseXcsrsort_bufferSizeExt(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
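  // csrsort reorders the column indices of each row in place and applies the
  // same permutation to the data array.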
  CUSPARSE_CALL(hipsparseXcsrsort(
      thr_entry->cusparse_handle,
      csr->num_rows, csr->num_cols, nnz,
      descr,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
      data.Ptr<int32_t>(),
      workspace));

  csr->sorted = true;

  // free resources
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
  device->FreeWorkspace(ctx, workspace);
}

template <>
void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);

  const auto& ctx = csr->indptr->ctx;
  const int64_t nnz = csr->indices->shape[0];
  const auto nbits = csr->indptr->dtype.bits;
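  // As in the 32-bit path, attach an identity permutation if there is no data.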
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, nbits, ctx);

  // hipcub's segmented radix sort writes to separate output arrays, so sort
  // into clones and swap them into the matrix afterwards.
  IdArray new_indices = csr->indices.Clone();
  IdArray new_data = csr->data.Clone();

  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
  const int64_t* key_in = csr->indices.Ptr<int64_t>();
  int64_t* key_out = new_indices.Ptr<int64_t>();
  const int64_t* value_in = csr->data.Ptr<int64_t>();
  int64_t* value_out = new_data.Ptr<int64_t>();

  // First call with a null workspace pointer only computes the required
  // temporary storage size; nothing is sorted yet.
  size_t workspace_size = 0;
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Second call performs the actual segmented sort of (index, data) pairs.
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size,
      key_in, key_out, value_in, value_out,
      nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream));

  csr->sorted = true;
  csr->indices = new_indices;
  csr->data = new_data;

  // free resources
  device->FreeWorkspace(ctx, workspace);
}

template void CSRSort_<kDLROCM, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLROCM, int64_t>(CSRMatrix* csr);

}  // namespace impl
}  // namespace aten
}  // namespace dgl