// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/csr_sort.cc
 * @brief Sort the column indices within each row of a CSR matrix
 */
#include <dgl/array.h>

#include <hipcub/hipcub.hpp>

#include "../../runtime/cuda/cuda_common.h"
#include "utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

/**
 * @brief Check whether each row is sorted.
 */
template <typename IdType>
__global__ void _SegmentIsSorted(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    int8_t* flags) {
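  // Grid-stride loop: each thread checks one row at a time, writing 1 if the
  // row's column indices are in non-decreasing order and 0 otherwise.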
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_rows) {
    bool f = true;
    for (IdType i = indptr[tx] + 1; f && i < indptr[tx + 1]; ++i) {
      f = (indices[i - 1] <= indices[i]);
    }
    flags[tx] = static_cast<int8_t>(f);
    tx += stride_x;
  }
}

template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const auto& ctx = csr.indptr->ctx;
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of num_rows bytes. It wastes a little bit of
  // memory but should be fine.
  int8_t* flags =
      static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
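  // One thread per row: FindNumThreads picks the block size and nb blocks
  // cover all rows.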
  const int nt = cuda::FindNumThreads(csr.num_rows);
  const int nb = (csr.num_rows + nt - 1) / nt;
  CUDA_KERNEL_CALL(
      _SegmentIsSorted, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
      csr.indices.Ptr<IdType>(), csr.num_rows, flags);
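  // The matrix is sorted iff every per-row flag is set.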
  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
  device->FreeWorkspace(ctx, flags);
  return ret;
}

template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);

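// The generic template is never reached on GPU; only the explicit
// specializations below are used.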
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  LOG(FATAL) << "Unreachable code";
}

template <>
void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
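  // 32-bit path: hipSPARSE's csrsort sorts each row's column indices in place
  // and applies the same permutation to the data array.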
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  // Allocate a hipSPARSE handle if needed.
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));

  NDArray indptr = csr->indptr;
  NDArray indices = csr->indices;
  const auto& ctx = indptr->ctx;
  const int64_t nnz = indices->shape[0];
  if (!aten::CSRHasData(*csr))
    csr->data = aten::Range(0, nnz, indices->dtype.bits, ctx);
  NDArray data = csr->data;

  size_t workspace_size = 0;
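  // Query the size of the temporary buffer csrsort needs before allocating it.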
  CUSPARSE_CALL(hipsparseXcsrsort_bufferSizeExt(
      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

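  // Describe the matrix (general type, zero-based indexing), then sort.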
  hipsparseMatDescr_t descr;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&descr));
  CUSPARSE_CALL(hipsparseSetMatType(descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
  CUSPARSE_CALL(hipsparseSetMatIndexBase(descr, HIPSPARSE_INDEX_BASE_ZERO));
  CUSPARSE_CALL(hipsparseXcsrsort(
      thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz, descr,
      indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), data.Ptr<int32_t>(),
      workspace));

  csr->sorted = true;

  // free resources
  CUSPARSE_CALL(hipsparseDestroyMatDescr(descr));
  device->FreeWorkspace(ctx, workspace);
}

template <>
void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
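  // 64-bit path: hipSPARSE's csrsort only supports 32-bit indices, so sort
  // each row segment with hipCUB's DeviceSegmentedRadixSort instead.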
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);

  const auto& ctx = csr->indptr->ctx;
  const int64_t nnz = csr->indices->shape[0];
  const auto nbits = csr->indptr->dtype.bits;
  if (!aten::CSRHasData(*csr)) csr->data = aten::Range(0, nnz, nbits, ctx);

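  // SortPairs is out-of-place, so clone indices and data as output buffers.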
  IdArray new_indices = csr->indices.Clone();
  IdArray new_data = csr->data.Clone();

  const int64_t* offsets = csr->indptr.Ptr<int64_t>();
  const int64_t* key_in = csr->indices.Ptr<int64_t>();
  int64_t* key_out = new_indices.Ptr<int64_t>();
  const int64_t* value_in = csr->data.Ptr<int64_t>();
  int64_t* value_out = new_data.Ptr<int64_t>();

  // Allocate workspace
  size_t workspace_size = 0;
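  // hipCUB convention: a first call with a null temporary-storage pointer only
  // computes workspace_size; no sorting happens yet.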
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(
      nullptr, workspace_size, key_in, key_out, value_in, value_out, nnz,
      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Compute
  CUDA_CALL(hipcub::DeviceSegmentedRadixSort::SortPairs(
      workspace, workspace_size, key_in, key_out, value_in, value_out, nnz,
      csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));

  csr->sorted = true;
  csr->indices = new_indices;
  csr->data = new_data;

  // free resources
  device->FreeWorkspace(ctx, workspace);
}

template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);
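
// Usage sketch (hypothetical caller, not part of this file): check before
// sorting to skip redundant work.
//   CSRMatrix csr = ...;  // CSR matrix with device (GPU) context
//   if (!CSRIsSorted<kDGLCUDA, int64_t>(csr)) CSRSort_<kDGLCUDA, int64_t>(&csr);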

}  // namespace impl
}  // namespace aten
}  // namespace dgl