coo_sort.cu 5.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/coo_sort.cu
 * \brief Sort COO index
 */
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

///////////////////////////// COOSort_ /////////////////////////////

/*!
 * \brief Sort the entries of a COO matrix in place with cuSPARSE.
 *
 * The entries are first sorted by row index. When \a sort_column is true, a
 * temporary row indptr array is built from the sorted rows and the column
 * indices are then sorted through the CSR sort routine. If the matrix carries
 * no data array, one is created with Range(0, nnz) before sorting; the data
 * array is handed to cuSPARSE as the permutation argument of both sort calls.
 *
 * \param coo Matrix to sort; its row, col and data arrays are modified and its
 *            row_sorted / col_sorted flags are updated.
 * \param sort_column Whether to also sort the column indices.
 */
template <DLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
  // TODO(minjie): Current implementation is based on cusparse which only supports
  //   int32_t. To support int64_t, we could use the Radix sort algorithm provided
  //   by CUB.
  CHECK(sizeof(IdType) == 4) << "CUDA COOSort does not support int64.";
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(coo->row->ctx);
  // Create the thread-local cusparse handle on first use, then attach it to
  // the stream recorded in the thread entry.
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
  }
  auto handle = thr_entry->cusparse_handle;
  CUSPARSE_CALL(cusparseSetStream(handle, thr_entry->stream));

  NDArray row = coo->row;
  NDArray col = coo->col;
  const auto& ctx = row->ctx;
  const auto nnz = row->shape[0];
  // Without a data array there is nothing to permute, so start from 0..nnz-1.
  if (!aten::COOHasData(*coo))
    coo->data = aten::Range(0, nnz, row->dtype.bits, ctx);
  NDArray data = coo->data;
  int32_t* row_idx = static_cast<int32_t*>(row->data);
  int32_t* col_idx = static_cast<int32_t*>(col->data);
  int32_t* perm = static_cast<int32_t*>(data->data);

  // Pass 1: sort by row.
  size_t row_buf_bytes = 0;
  CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
      handle,
      coo->num_rows, coo->num_cols,
      nnz,
      row_idx,
      col_idx,
      &row_buf_bytes));
  void* row_buf = device->AllocWorkspace(ctx, row_buf_bytes);
  CUSPARSE_CALL(cusparseXcoosortByRow(
      handle,
      coo->num_rows, coo->num_cols,
      nnz,
      row_idx,
      col_idx,
      perm,
      row_buf));
  device->FreeWorkspace(ctx, row_buf);

  if (sort_column) {
    // Pass 2: build a row indptr array from the sorted rows and run csrsort.
    int32_t* indptr = static_cast<int32_t*>(
        device->AllocWorkspace(ctx, (coo->num_rows + 1) * sizeof(IdType)));
    CUSPARSE_CALL(cusparseXcoo2csr(
        handle,
        row_idx,
        nnz,
        coo->num_rows,
        indptr,
        CUSPARSE_INDEX_BASE_ZERO));
    size_t csr_buf_bytes = 0;
    CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
        handle,
        coo->num_rows,
        coo->num_cols,
        nnz,
        indptr,
        col_idx,
        &csr_buf_bytes));
    void* csr_buf = device->AllocWorkspace(ctx, csr_buf_bytes);
    cusparseMatDescr_t descr;
    CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
    CUSPARSE_CALL(cusparseXcsrsort(
        handle,
        coo->num_rows,
        coo->num_cols,
        nnz,
        descr,
        indptr,
        col_idx,
        perm,
        csr_buf));
    CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
    device->FreeWorkspace(ctx, csr_buf);
    device->FreeWorkspace(ctx, indptr);
  }

  // Record the resulting ordering on the matrix.
  coo->row_sorted = true;
  coo->col_sorted = sort_column;
}

// Explicit instantiations for kDLGPU. The int64_t variant is instantiated as
// well, but COOSort_ rejects it at runtime via its sizeof(IdType) == 4 check.
template void COOSort_<kDLGPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDLGPU, int64_t>(COOMatrix*, bool);

///////////////////////////// COOIsSorted /////////////////////////////

/*!
 * \brief Flag, for every COO entry, whether it is in order relative to the
 *        entry immediately before it.
 *
 * Each thread walks the entries in a grid-stride loop, so every index in
 * [0, nnz) is visited regardless of how many threads are launched along x.
 *
 * \param row Row indices; nnz elements are read.
 * \param col Column indices; nnz elements are read.
 * \param nnz Number of entries.
 * \param row_sorted Output flags, nnz elements. Element i is 1 iff
 *        row[i-1] <= row[i]; element 0 is always 1.
 * \param col_sorted Output flags, nnz elements. Element i is 1 iff
 *        row[i-1] < row[i] or col[i-1] <= col[i]; element 0 is always 1.
 */
template <typename IdType>
__global__ void _COOIsSortedKernel(
    const IdType* row, const IdType* col,
    int64_t nnz, int8_t* row_sorted, int8_t* col_sorted) {
  // nnz is 64-bit, so the index and the stride must be 64-bit as well;
  // with 32-bit ints the arithmetic wraps once the index passes 2^31 - 1.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < nnz) {
    if (tx == 0) {
      // The first entry has no predecessor and is in order by definition.
      row_sorted[0] = 1;
      col_sorted[0] = 1;
    } else {
      row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]);
      // Columns only need to be non-decreasing within the same row.
      col_sorted[tx] = static_cast<int8_t>(
          row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]);
    }
    tx += stride_x;
  }
}

/*!
 * \brief Check whether the entries of a COO matrix are sorted.
 *
 * One flag per entry is computed by _COOIsSortedKernel for rows and for
 * columns, and each flag array is then reduced with cuda::AllTrue.
 *
 * \param coo The matrix to inspect; only its row and col arrays are read.
 * \return A pair (row_sorted, col_sorted). col_sorted is reported as false
 *         whenever row_sorted is false. A matrix with no entries yields
 *         {true, true}.
 */
template <DLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  const int64_t nnz = coo.row->shape[0];
  // An empty index is trivially sorted. Returning here also keeps us from
  // computing a zero-sized grid below, which is not a valid kernel launch.
  if (nnz == 0)
    return {true, true};
  const auto& ctx = coo.row->ctx;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(ctx);
  // We allocate a workspace of 2*nnz bytes (one int8 flag per entry for rows
  // and one for columns). It wastes a little bit memory but should be fine.
  int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
  int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
  const int nt = cuda::FindNumThreads(nnz);
  // Ceil-divide so that the grid covers every entry.
  const int nb = (nnz + nt - 1) / nt;
  _COOIsSortedKernel<<<nb, nt, 0, thr_entry->stream>>>(
      coo.row.Ptr<IdType>(), coo.col.Ptr<IdType>(),
      nnz, row_flags, col_flags);

  const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx);
  // Column order is only meaningful when the rows are sorted.
  const bool col_sorted = row_sorted? cuda::AllTrue(col_flags, nnz, ctx) : false;

  device->FreeWorkspace(ctx, row_flags);
  device->FreeWorkspace(ctx, col_flags);

  return {row_sorted, col_sorted};
}

// Explicit instantiations for kDLGPU with 32-bit and 64-bit index types.
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix);
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix);

}  // namespace impl
}  // namespace aten
}  // namespace dgl