csr_sort.cc 5.55 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/csr_sort.cc
 * @brief CSR sorting
 */
#include <dgl/array.h>
7
#include <dgl/runtime/parallel_for.h>
8

9
#include <algorithm>
10
#include <numeric>
11
12
13
14
15
16
17
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

///////////////////////////// CSRIsSorted /////////////////////////////
18
template <DGLDeviceType XPU, typename IdType>
19
bool CSRIsSorted(CSRMatrix csr) {
20
21
22
23
24
25
26
27
28
  const IdType *indptr = csr.indptr.Ptr<IdType>();
  const IdType *indices = csr.indices.Ptr<IdType>();
  return runtime::parallel_reduce(
      0, csr.num_rows, 1, 1,
      [indptr, indices](size_t b, size_t e, bool ident) {
        for (size_t row = b; row < e; ++row) {
          for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) {
            if (indices[i - 1] > indices[i]) return false;
          }
29
        }
30
31
32
        return ident;
      },
      [](bool a, bool b) { return a && b; });
33
34
}

35
36
template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
37
38
39

///////////////////////////// CSRSort /////////////////////////////

40
template <DGLDeviceType XPU, typename IdType>
41
void CSRSort_(CSRMatrix *csr) {
42
43
44
  typedef std::pair<IdType, IdType> ShufflePair;
  const int64_t num_rows = csr->num_rows;
  const int64_t nnz = csr->indices->shape[0];
45
46
  const IdType *indptr_data = static_cast<IdType *>(csr->indptr->data);
  IdType *indices_data = static_cast<IdType *>(csr->indices->data);
47
48
49
50
51
52

  if (CSRIsSorted(*csr)) {
    csr->sorted = true;
    return;
  }

53
54
55
  if (!CSRHasData(*csr)) {
    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
  }
56
  IdType *eid_data = static_cast<IdType *>(csr->data->data);
57
58
59

  runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) {
    for (auto row = b; row < e; ++row) {
60
      const int64_t num_cols = indptr_data[row + 1] - indptr_data[row];
61
      std::vector<ShufflePair> reorder_vec(num_cols);
62
63
64
65
66
67
68
      IdType *col = indices_data + indptr_data[row];
      IdType *eid = eid_data + indptr_data[row];

      for (int64_t i = 0; i < num_cols; i++) {
        reorder_vec[i].first = col[i];
        reorder_vec[i].second = eid[i];
      }
69
70
71
72
73
      std::sort(
          reorder_vec.begin(), reorder_vec.end(),
          [](const ShufflePair &e1, const ShufflePair &e2) {
            return e1.first < e2.first;
          });
74
75
76
77
78
      for (int64_t i = 0; i < num_cols; i++) {
        col[i] = reorder_vec[i].first;
        eid[i] = reorder_vec[i].second;
      }
    }
79
80
  });

81
82
83
  csr->sorted = true;
}

84
85
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix *csr);
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix *csr);
86

87
template <DGLDeviceType XPU, typename IdType, typename TagType>
88
89
90
91
std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
  const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
  const auto indices_data = static_cast<const IdType *>(csr.indices->data);
92
93
94
  const auto eid_data = aten::CSRHasData(csr)
                            ? static_cast<const IdType *>(csr.data->data)
                            : nullptr;
95
96
97
  const auto tag_data = static_cast<const TagType *>(tag_array->data);
  const int64_t num_rows = csr.num_rows;

98
99
  NDArray tag_pos = NDArray::Empty(
      {csr.num_rows, num_tags + 1}, csr.indptr->dtype, csr.indptr->ctx);
100
101
102
  auto tag_pos_data = static_cast<IdType *>(tag_pos->data);
  std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);

103
104
105
106
107
  aten::CSRMatrix output(
      csr.num_rows, csr.num_cols, csr.indptr.Clone(), csr.indices.Clone(),
      NDArray::Empty(
          {csr.indices->shape[0]}, csr.indices->dtype, csr.indices->ctx),
      csr.sorted);
108
109
110
111

  auto out_indices_data = static_cast<IdType *>(output.indices->data);
  auto out_eid_data = static_cast<IdType *>(output.data->data);

112
113
114
115
116
117
118
119
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
    for (auto src = b; src < e; ++src) {
      const IdType start = indptr_data[src];
      const IdType end = indptr_data[src + 1];

      auto tag_pos_row = tag_pos_data + src * (num_tags + 1);
      std::vector<IdType> pointer(num_tags, 0);

120
      for (IdType ptr = start; ptr < end; ++ptr) {
121
122
        const IdType eid = eid_data ? eid_data[ptr] : ptr;
        const TagType tag = tag_data[eid];
123
124
125
126
        CHECK_LT(tag, num_tags);
        ++tag_pos_row[tag + 1];
      }  // count

127
      for (TagType tag = 1; tag <= num_tags; ++tag) {
128
129
130
        tag_pos_row[tag] += tag_pos_row[tag - 1];
      }  // cumulate

131
      for (IdType ptr = start; ptr < end; ++ptr) {
132
        const IdType dst = indices_data[ptr];
133
134
        const IdType eid = eid_data ? eid_data[ptr] : ptr;
        const TagType tag = tag_data[eid];
135
136
137
138
139
140
141
        const IdType offset = tag_pos_row[tag] + pointer[tag];
        CHECK_LT(offset, tag_pos_row[tag + 1]);
        ++pointer[tag];

        out_indices_data[start + offset] = dst;
        out_eid_data[start + offset] = eid;
      }
142
    }
143
  });
144
145
146
147
  output.sorted = false;
  return std::make_pair(output, tag_pos);
}

148
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(
149
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
150
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(
151
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
152
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(
153
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
154
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(
155
156
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);

157
158
159
}  // namespace impl
}  // namespace aten
}  // namespace dgl