csr_sort.cc 5.63 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/csr_sort.cc
 * \brief CSR sorting
 */
#include <dgl/array.h>
7
#include <dgl/runtime/parallel_for.h>
8
9
10
11
12
13
14
15
16
#include <numeric>
#include <algorithm>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

///////////////////////////// CSRIsSorted /////////////////////////////
17
template <DGLDeviceType XPU, typename IdType>
18
19
20
bool CSRIsSorted(CSRMatrix csr) {
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
21
22
23
24
25
26
27
  return runtime::parallel_reduce(0, csr.num_rows, 1, 1,
    [indptr, indices](size_t b, size_t e, bool ident) {
      for (size_t row = b; row < e; ++row) {
        for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) {
          if (indices[i - 1] > indices[i])
            return false;
        }
28
      }
29
30
31
      return ident;
    },
    [](bool a, bool b) { return a && b; });
32
33
}

34
35
// Explicit instantiations of CSRIsSorted for kDGLCPU with 64-bit and 32-bit
// index types.
template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
36
37
38

///////////////////////////// CSRSort /////////////////////////////

39
template <DGLDeviceType XPU, typename IdType>
40
41
42
43
44
45
void CSRSort_(CSRMatrix* csr) {
  typedef std::pair<IdType, IdType> ShufflePair;
  const int64_t num_rows = csr->num_rows;
  const int64_t nnz = csr->indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr->indptr->data);
  IdType* indices_data = static_cast<IdType*>(csr->indices->data);
46
47
48
49
50
51

  if (CSRIsSorted(*csr)) {
    csr->sorted = true;
    return;
  }

52
53
54
55
  if (!CSRHasData(*csr)) {
    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
  }
  IdType* eid_data = static_cast<IdType*>(csr->data->data);
56
57
58

  runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) {
    for (auto row = b; row < e; ++row) {
59
      const int64_t num_cols = indptr_data[row + 1] - indptr_data[row];
60
      std::vector<ShufflePair> reorder_vec(num_cols);
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
      IdType *col = indices_data + indptr_data[row];
      IdType *eid = eid_data + indptr_data[row];

      for (int64_t i = 0; i < num_cols; i++) {
        reorder_vec[i].first = col[i];
        reorder_vec[i].second = eid[i];
      }
      std::sort(reorder_vec.begin(), reorder_vec.end(),
                [](const ShufflePair &e1, const ShufflePair &e2) {
                  return e1.first < e2.first;
                });
      for (int64_t i = 0; i < num_cols; i++) {
        col[i] = reorder_vec[i].first;
        eid[i] = reorder_vec[i].second;
      }
    }
77
78
  });

79
80
81
  csr->sorted = true;
}

82
83
// Explicit instantiations of CSRSort_ for kDGLCPU with 64-bit and 32-bit
// index types.
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr);
84

85
template <DGLDeviceType XPU, typename IdType, typename TagType>
86
87
88
89
std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
  const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
  const auto indices_data = static_cast<const IdType *>(csr.indices->data);
90
91
92
  const auto eid_data = aten::CSRHasData(csr)
                            ? static_cast<const IdType *>(csr.data->data)
                            : nullptr;
93
94
95
96
97
98
99
100
  const auto tag_data = static_cast<const TagType *>(tag_array->data);
  const int64_t num_rows = csr.num_rows;

  NDArray tag_pos = NDArray::Empty({csr.num_rows, num_tags + 1},
      csr.indptr->dtype, csr.indptr->ctx);
  auto tag_pos_data = static_cast<IdType *>(tag_pos->data);
  std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);

101
102
103
104
105
  aten::CSRMatrix output(csr.num_rows, csr.num_cols, csr.indptr.Clone(),
                         csr.indices.Clone(),
                         NDArray::Empty({csr.indices->shape[0]},
                                        csr.indices->dtype, csr.indices->ctx),
                         csr.sorted);
106
107
108
109

  auto out_indices_data = static_cast<IdType *>(output.indices->data);
  auto out_eid_data = static_cast<IdType *>(output.data->data);

110
111
112
113
114
115
116
117
118
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
    for (auto src = b; src < e; ++src) {
      const IdType start = indptr_data[src];
      const IdType end = indptr_data[src + 1];

      auto tag_pos_row = tag_pos_data + src * (num_tags + 1);
      std::vector<IdType> pointer(num_tags, 0);

      for (IdType ptr = start ; ptr < end ; ++ptr) {
119
120
        const IdType eid = eid_data ? eid_data[ptr] : ptr;
        const TagType tag = tag_data[eid];
121
122
123
124
125
126
127
128
129
130
        CHECK_LT(tag, num_tags);
        ++tag_pos_row[tag + 1];
      }  // count

      for (TagType tag = 1 ; tag <= num_tags; ++tag) {
        tag_pos_row[tag] += tag_pos_row[tag - 1];
      }  // cumulate

      for (IdType ptr = start ; ptr < end ; ++ptr) {
        const IdType dst = indices_data[ptr];
131
132
        const IdType eid = eid_data ? eid_data[ptr] : ptr;
        const TagType tag = tag_data[eid];
133
134
135
136
137
138
139
        const IdType offset = tag_pos_row[tag] + pointer[tag];
        CHECK_LT(offset, tag_pos_row[tag + 1]);
        ++pointer[tag];

        out_indices_data[start + offset] = dst;
        out_eid_data[start + offset] = eid;
      }
140
    }
141
  });
142
143
144
145
  output.sorted = false;
  return std::make_pair(output, tag_pos);
}

146
// Explicit instantiations of CSRSortByTag for kDGLCPU, covering every
// (IdType, TagType) combination of int64_t and int32_t.
template std::pair<CSRMatrix, NDArray>
CSRSortByTag<kDGLCPU, int64_t, int64_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray>
CSRSortByTag<kDGLCPU, int64_t, int32_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray>
CSRSortByTag<kDGLCPU, int32_t, int64_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray>
CSRSortByTag<kDGLCPU, int32_t, int32_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);

155
156
157
}  // namespace impl
}  // namespace aten
}  // namespace dgl